From de2b79985a236e4127bf1fe8e7c4adb7385db7ef Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 31 May 2024 11:35:17 -0400 Subject: [PATCH 1/2] llvm.julia.gc_preserve_begin splatting --- src/compiler/validation.jl | 44 +++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index f8fa3a4cd2..6a66aef8c4 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -167,10 +167,52 @@ function rewrite_ccalls!(mod::LLVM.Module) replaceAndErase = Tuple{Instruction, Instruction}[] for bb in blocks(f), inst in instructions(bb) if isa(inst, LLVM.CallInst) + fn = called_operand(inst) changed = false - newbundles = OperandBundleDef[] B = IRBuilder() position!(B, inst) + if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" + uservals = LLVM.Value[] + for lval in collect(arguments(inst)) + llty = value_type(lval) + if !isa(llty, LLVM.PointerType) || LLVM.addrspace(llty) != 10 + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + changed = true + end + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), prevname) + for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx); + Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count)) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j in 1:count + LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j)) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + continue + end + newbundles = OperandBundleDef[] for bunduse in operand_bundles(inst) bunduse = LLVM.OperandBundleDef(bunduse) if LLVM.tag_name(bunduse) != "jl_roots" From eb213d649d37b9070749478c9f98997ccec54b85 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 31 May 2024 12:24:36 -0400 Subject: [PATCH 2/2] fix --- src/compiler/validation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 6a66aef8c4..68eb4a5bca 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -175,7 +175,7 @@ function rewrite_ccalls!(mod::LLVM.Module) uservals = LLVM.Value[] for lval in collect(arguments(inst)) llty = value_type(lval) - if !isa(llty, LLVM.PointerType) || LLVM.addrspace(llty) != 10 + if isa(llty, LLVM.PointerType) push!(uservals, lval) continue end @@ -223,7 +223,7 @@ function rewrite_ccalls!(mod::LLVM.Module) subchanged = false for lval in LLVM.inputs(bunduse) llty = value_type(lval) - if !isa(llty, LLVM.PointerType) || LLVM.addrspace(llty) != 10 + if isa(llty, LLVM.PointerType) push!(uservals, lval) continue end