From 6e867ba81bab2abafaed85f56a0f6e7cc38b01a2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 18 Sep 2024 16:46:45 -0500 Subject: [PATCH] Try llvm.jl 9.1 (#1857) * Try llvm.jl 9.1 * fixups * more fix * bump version * fix * fix * fix --- Project.toml | 4 ++-- src/compiler.jl | 46 +++++++++++++++++++++++++------------- src/compiler/optimize.jl | 20 ++++++++--------- src/compiler/orcv2.jl | 2 +- src/compiler/utils.jl | 10 ++++++++- src/compiler/validation.jl | 4 ++-- 6 files changed, 55 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 19315d01dc..9cf8028a76 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.13.0" +version = "0.13.1" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -37,7 +37,7 @@ ChainRulesCore = "1" EnzymeCore = "0.8" Enzyme_jll = "0.0.150" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27" -LLVM = "6.1, 7, 8, =9.0" +LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" ObjectFile = "0.4" Preferences = "1.4" diff --git a/src/compiler.jl b/src/compiler.jl index 1d21fb99a1..be4679d263 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -4045,6 +4045,22 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr return adjointf, augmented_primalf, TapeType end +function get_subprogram(f::LLVM.Function) + @static if isdefined(LLVM, :subprogram) + LLVM.subprogram(f) + else + LLVM.get_subprogram(f) + end +end + +function set_subprogram!(f::LLVM.Function, sp) + @static if isdefined(LLVM, :subprogram) + LLVM.subprogram!(f, sp) + else + LLVM.set_subprogram!(f, sp) + end +end + function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, Mode::API.CDerivativeMode, augmented, width, returnPrimal, shadow_init, world, interp) is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal @@ -4422,8 +4438,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(args, psret) end res = LLVM.call!(builder, LLVM.function_type(llvmf), llvmf, args) - if LLVM.get_subprogram(llvmf) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvmf) !== nothing + metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end if psret !== nothing res = load!(builder, convert(LLVMType, Func_RT), psret) @@ -4449,8 +4465,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, end val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms) - if LLVM.get_subprogram(llvm_f) !== nothing - metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvm_f) !== nothing + metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end @inline function fixup_abi(index, value) @@ -4514,8 +4530,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, push!(function_attributes(cf), EnumAttribute("alwaysinline", 0)) for shadowv in shadows c = call!(builder, LLVM.function_type(cf), cf, [shadowv]) - if LLVM.get_subprogram(llvm_f) !== nothing - metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) ) + if get_subprogram(llvm_f) !== nothing + metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) ) end end end @@ -5027,9 +5043,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function wrapper_ft = LLVM.FunctionType(RT, wrapper_types) wrapper_f = LLVM.Function(mod, LLVM.name(entry_f), wrapper_ft) callconv!(wrapper_f, callconv(entry_f)) - sfn = LLVM.get_subprogram(entry_f) + sfn = get_subprogram(entry_f) if sfn !== nothing - LLVM.set_subprogram!(wrapper_f, sfn) + set_subprogram!(wrapper_f, sfn) end hasReturnsTwice = any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(entry_f)))) @@ -5107,8 +5123,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function entry = BasicBlock(wrapper_f, "entry") position!(builder, entry) - if LLVM.get_subprogram(entry_f) !== nothing - debuglocation!(builder, DILocation(0, 0, LLVM.get_subprogram(entry_f))) + if get_subprogram(entry_f) !== nothing + debuglocation!(builder, DILocation(0, 0, get_subprogram(entry_f))) end wrapper_args = Vector{LLVM.Value}() @@ -5178,8 +5194,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end res = call!(builder, LLVM.function_type(entry_f), entry_f, wrapper_args) - if LLVM.get_subprogram(entry_f) !== nothing - metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(entry_f) ) + if get_subprogram(entry_f) !== nothing + metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(entry_f) ) end callconv!(res, LLVM.callconv(entry_f)) @@ -5411,10 +5427,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function LLVM.run!(pm, mod) end if haskey(globals(mod), "llvm.used") - unsafe_delete!(mod, globals(mod)["llvm.used"]) + eraseInst(mod, globals(mod)["llvm.used"]) for u in user.(collect(uses(entry_f))) if isa(u, LLVM.GlobalVariable) && endswith(LLVM.name(u), "_slot") && startswith(LLVM.name(u), "julia") - unsafe_delete!(mod, u) + eraseInst(mod, u) end end end @@ -6469,7 +6485,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; st = LLVM.user(u) LLVM.API.LLVMInstructionEraseFromParent(st) end - LLVM.unsafe_delete!(mod, f) + eraseInst(mod, f) end linkage!(adjointf, LLVM.API.LLVMExternalLinkage) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 8c6385edb8..2e3e8194c9 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -533,7 +533,7 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module) end end for inst in todel - unsafe_delete!(LLVM.parent(inst), inst) + eraseInst(LLVM.parent(inst), inst) end end end @@ -1145,7 +1145,7 @@ function prop_global!(g) end end replace_uses!(var, res) - unsafe_delete!(LLVM.parent(var), var) + eraseInst(LLVM.parent(var), var) continue end if isa(var, LLVM.AddrSpaceCastInst) @@ -1441,7 +1441,7 @@ function propagate_returned!(mod::LLVM.Module) end if !illegalUse for c in reverse(torem) - unsafe_delete!(LLVM.parent(c), c) + eraseInst(LLVM.parent(c), c) end B = IRBuilder() position!(B, first(instructions(first(blocks(fn))))) @@ -1617,7 +1617,7 @@ function propagate_returned!(mod::LLVM.Module) end API.EnzymeSetCalledFunction(un, nfn, toremove) end - unsafe_delete!(mod, fn) + eraseInst(mod, fn) changed = true catch break @@ -2030,26 +2030,26 @@ function removeDeadArgs!(mod::LLVM.Module, tm) for u in LLVM.uses(rfunc) u = LLVM.user(u) - unsafe_delete!(LLVM.parent(u), u) + eraseInst(LLVM.parent(u), u) end - unsafe_delete!(mod, rfunc) + eraseInst(mod, rfunc) for u in LLVM.uses(sfunc) u = LLVM.user(u) - unsafe_delete!(LLVM.parent(u), u) + eraseInst(LLVM.parent(u), u) end - unsafe_delete!(mod, sfunc) + eraseInst(mod, sfunc) for fn in functions(mod) for b in blocks(fn) inst = first(LLVM.instructions(b)) if isa(inst, LLVM.CallInst) fn = LLVM.called_operand(inst) if fn == func - unsafe_delete!(b, inst) + eraseInst(b, inst) end end end end - unsafe_delete!(mod, func) + eraseInst(mod, func) end function optimize!(mod::LLVM.Module, tm) diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 40d13eea80..78ff089e7d 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -224,7 +224,7 @@ function get_trampoline(job) # but it would be nicer if _thunk just codegen'd the half # we need. other_func = functions(mod)[other_name] - LLVM.unsafe_delete!(mod, other_func) + Compiler.eraseInst(mod, other_func) end tsm = move_to_threadsafe(mod) diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 6615b6bd40..cde5d2cade 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -313,6 +313,14 @@ function reinsert_gcmarker!(func, PB=nothing) end end +function eraseInst(bb, inst) + @static if isdefined(LLVM, Symbol("erase!")) + LLVM.erase!(inst) + else + unsafe_delete!(bb, inst) + end +end + function unique_gcmarker!(func) entry_bb = first(blocks(func)) pgcstack_func = declare_pgcstack!(LLVM.parent(func)) @@ -327,7 +335,7 @@ function unique_gcmarker!(func) for i in 2:length(found) LLVM.replace_uses!(found[i], found[1]) ops = LLVM.collect(operands(found[i])) - Base.unsafe_delete!(entry_bb, found[i]) + eraseInst(entry_bb, found[i]) end end return nothing diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 51aeacf675..3df37be117 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -112,7 +112,7 @@ function restore_lookups(mod::LLVM.Module) if haskey(functions(mod), k) f = functions(mod)[k] replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstIntToPtr(ConstantInt(T_size_t, convert(UInt, v)), value_type(f)))) - unsafe_delete!(mod, f) + eraseInst(mod, f) end end end @@ -272,7 +272,7 @@ function check_ir!(job, errors, mod::LLVM.Module) mfn = LLVM.API.LLVMAddFunction(mod, "malloc", LLVM.FunctionType(ptr8, parameters(prev_ft))) replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) - unsafe_delete!(mod, f) + eraseInst(mod, f) end rewrite_ccalls!(mod) for f in collect(functions(mod))