Skip to content

Commit

Permalink
Try llvm.jl 9.1 (#1857)
Browse files Browse the repository at this point in the history
* Try llvm.jl 9.1

* fixups

* more fix

* bump version

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Sep 18, 2024
1 parent f49d1fc commit 6e867ba
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 31 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.13.0"
version = "0.13.1"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down Expand Up @@ -37,7 +37,7 @@ ChainRulesCore = "1"
EnzymeCore = "0.8"
Enzyme_jll = "0.0.150"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27"
LLVM = "6.1, 7, 8, =9.0"
LLVM = "6.1, 7, 8, 9"
LogExpFunctions = "0.3"
ObjectFile = "0.4"
Preferences = "1.4"
Expand Down
46 changes: 31 additions & 15 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4045,6 +4045,22 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
return adjointf, augmented_primalf, TapeType
end

function get_subprogram(f::LLVM.Function)
@static if isdefined(LLVM, :subprogram)
LLVM.subprogram(f)
else
LLVM.get_subprogram(f)
end
end

function set_subprogram!(f::LLVM.Function, sp)
@static if isdefined(LLVM, :subprogram)
LLVM.subprogram!(f, sp)
else
LLVM.set_subprogram!(f, sp)
end
end

function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, Mode::API.CDerivativeMode, augmented, width, returnPrimal, shadow_init, world, interp)
is_adjoint = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModeCombined
is_split = Mode == API.DEM_ReverseModeGradient || Mode == API.DEM_ReverseModePrimal
Expand Down Expand Up @@ -4422,8 +4438,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
push!(args, psret)
end
res = LLVM.call!(builder, LLVM.function_type(llvmf), llvmf, args)
if LLVM.get_subprogram(llvmf) !== nothing
metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) )
if get_subprogram(llvmf) !== nothing
metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) )
end
if psret !== nothing
res = load!(builder, convert(LLVMType, Func_RT), psret)
Expand All @@ -4449,8 +4465,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
end

val = call!(builder, LLVM.function_type(enzymefn), enzymefn, realparms)
if LLVM.get_subprogram(llvm_f) !== nothing
metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) )
if get_subprogram(llvm_f) !== nothing
metadata(val)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) )
end

@inline function fixup_abi(index, value)
Expand Down Expand Up @@ -4514,8 +4530,8 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
push!(function_attributes(cf), EnumAttribute("alwaysinline", 0))
for shadowv in shadows
c = call!(builder, LLVM.function_type(cf), cf, [shadowv])
if LLVM.get_subprogram(llvm_f) !== nothing
metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(llvm_f) )
if get_subprogram(llvm_f) !== nothing
metadata(c)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(llvm_f) )
end
end
end
Expand Down Expand Up @@ -5027,9 +5043,9 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
wrapper_ft = LLVM.FunctionType(RT, wrapper_types)
wrapper_f = LLVM.Function(mod, LLVM.name(entry_f), wrapper_ft)
callconv!(wrapper_f, callconv(entry_f))
sfn = LLVM.get_subprogram(entry_f)
sfn = get_subprogram(entry_f)
if sfn !== nothing
LLVM.set_subprogram!(wrapper_f, sfn)
set_subprogram!(wrapper_f, sfn)
end

hasReturnsTwice = any(map(k->kind(k)==kind(EnumAttribute("returns_twice")), collect(function_attributes(entry_f))))
Expand Down Expand Up @@ -5107,8 +5123,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function

entry = BasicBlock(wrapper_f, "entry")
position!(builder, entry)
if LLVM.get_subprogram(entry_f) !== nothing
debuglocation!(builder, DILocation(0, 0, LLVM.get_subprogram(entry_f)))
if get_subprogram(entry_f) !== nothing
debuglocation!(builder, DILocation(0, 0, get_subprogram(entry_f)))
end

wrapper_args = Vector{LLVM.Value}()
Expand Down Expand Up @@ -5178,8 +5194,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
end
res = call!(builder, LLVM.function_type(entry_f), entry_f, wrapper_args)

if LLVM.get_subprogram(entry_f) !== nothing
metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, LLVM.get_subprogram(entry_f) )
if get_subprogram(entry_f) !== nothing
metadata(res)[LLVM.MD_dbg] = DILocation( 0, 0, get_subprogram(entry_f) )
end

callconv!(res, LLVM.callconv(entry_f))
Expand Down Expand Up @@ -5411,10 +5427,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
LLVM.run!(pm, mod)
end
if haskey(globals(mod), "llvm.used")
unsafe_delete!(mod, globals(mod)["llvm.used"])
eraseInst(mod, globals(mod)["llvm.used"])
for u in user.(collect(uses(entry_f)))
if isa(u, LLVM.GlobalVariable) && endswith(LLVM.name(u), "_slot") && startswith(LLVM.name(u), "julia")
unsafe_delete!(mod, u)
eraseInst(mod, u)
end
end
end
Expand Down Expand Up @@ -6469,7 +6485,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
st = LLVM.user(u)
LLVM.API.LLVMInstructionEraseFromParent(st)
end
LLVM.unsafe_delete!(mod, f)
eraseInst(mod, f)
end

linkage!(adjointf, LLVM.API.LLVMExternalLinkage)
Expand Down
20 changes: 10 additions & 10 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ function memcpy_alloca_to_loadstore(mod::LLVM.Module)
end
end
for inst in todel
unsafe_delete!(LLVM.parent(inst), inst)
eraseInst(LLVM.parent(inst), inst)
end
end
end
Expand Down Expand Up @@ -1145,7 +1145,7 @@ function prop_global!(g)
end
end
replace_uses!(var, res)
unsafe_delete!(LLVM.parent(var), var)
eraseInst(LLVM.parent(var), var)
continue
end
if isa(var, LLVM.AddrSpaceCastInst)
Expand Down Expand Up @@ -1441,7 +1441,7 @@ function propagate_returned!(mod::LLVM.Module)
end
if !illegalUse
for c in reverse(torem)
unsafe_delete!(LLVM.parent(c), c)
eraseInst(LLVM.parent(c), c)
end
B = IRBuilder()
position!(B, first(instructions(first(blocks(fn)))))
Expand Down Expand Up @@ -1617,7 +1617,7 @@ function propagate_returned!(mod::LLVM.Module)
end
API.EnzymeSetCalledFunction(un, nfn, toremove)
end
unsafe_delete!(mod, fn)
eraseInst(mod, fn)
changed = true
catch
break
Expand Down Expand Up @@ -2030,26 +2030,26 @@ function removeDeadArgs!(mod::LLVM.Module, tm)

for u in LLVM.uses(rfunc)
u = LLVM.user(u)
unsafe_delete!(LLVM.parent(u), u)
eraseInst(LLVM.parent(u), u)
end
unsafe_delete!(mod, rfunc)
eraseInst(mod, rfunc)
for u in LLVM.uses(sfunc)
u = LLVM.user(u)
unsafe_delete!(LLVM.parent(u), u)
eraseInst(LLVM.parent(u), u)
end
unsafe_delete!(mod, sfunc)
eraseInst(mod, sfunc)
for fn in functions(mod)
for b in blocks(fn)
inst = first(LLVM.instructions(b))
if isa(inst, LLVM.CallInst)
fn = LLVM.called_operand(inst)
if fn == func
unsafe_delete!(b, inst)
eraseInst(b, inst)
end
end
end
end
unsafe_delete!(mod, func)
eraseInst(mod, func)
end

function optimize!(mod::LLVM.Module, tm)
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/orcv2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ function get_trampoline(job)
# but it would be nicer if _thunk just codegen'd the half
# we need.
other_func = functions(mod)[other_name]
LLVM.unsafe_delete!(mod, other_func)
Compiler.eraseInst(mod, other_func)
end

tsm = move_to_threadsafe(mod)
Expand Down
10 changes: 9 additions & 1 deletion src/compiler/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,14 @@ function reinsert_gcmarker!(func, PB=nothing)
end
end

function eraseInst(bb, inst)
@static if isdefined(LLVM, Symbol("erase!"))
LLVM.erase!(inst)
else
unsafe_delete!(bb, inst)
end
end

function unique_gcmarker!(func)
entry_bb = first(blocks(func))
pgcstack_func = declare_pgcstack!(LLVM.parent(func))
Expand All @@ -327,7 +335,7 @@ function unique_gcmarker!(func)
for i in 2:length(found)
LLVM.replace_uses!(found[i], found[1])
ops = LLVM.collect(operands(found[i]))
Base.unsafe_delete!(entry_bb, found[i])
eraseInst(entry_bb, found[i])
end
end
return nothing
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function restore_lookups(mod::LLVM.Module)
if haskey(functions(mod), k)
f = functions(mod)[k]
replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstIntToPtr(ConstantInt(T_size_t, convert(UInt, v)), value_type(f))))
unsafe_delete!(mod, f)
eraseInst(mod, f)
end
end
end
Expand Down Expand Up @@ -272,7 +272,7 @@ function check_ir!(job, errors, mod::LLVM.Module)

mfn = LLVM.API.LLVMAddFunction(mod, "malloc", LLVM.FunctionType(ptr8, parameters(prev_ft)))
replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f))))
unsafe_delete!(mod, f)
eraseInst(mod, f)
end
rewrite_ccalls!(mod)
for f in collect(functions(mod))
Expand Down

2 comments on commit 6e867ba

@wsmoses
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/115456

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.1 -m "<description of version>" 6e867ba81bab2abafaed85f56a0f6e7cc38b01a2
git push origin v0.13.1

Please sign in to comment.