Skip to content

Commit

Permalink
llvm.julia.gc_preserve_begin splatting (#1486)
Browse files Browse the repository at this point in the history
* llvm.julia.gc_preserve_begin splatting

* fix
  • Loading branch information
wsmoses authored Jun 2, 2024
1 parent 7526a5c commit 21b0762
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,52 @@ function rewrite_ccalls!(mod::LLVM.Module)
replaceAndErase = Tuple{Instruction, Instruction}[]
for bb in blocks(f), inst in instructions(bb)
if isa(inst, LLVM.CallInst)
fn = called_operand(inst)
changed = false
newbundles = OperandBundleDef[]
B = IRBuilder()
position!(B, inst)
if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin"
uservals = LLVM.Value[]
for lval in collect(arguments(inst))
llty = value_type(lval)
if isa(llty, LLVM.PointerType)
push!(uservals, lval)
continue
end
vals = get_julia_inner_types(B, nothing, lval)
for v in vals
if isa(v, LLVM.PointerNull)
subchanged = true
continue
end
push!(uservals, v)
end
if length(vals) == 1 && vals[1] == lval
continue
end
changed = true
end
if changed
prevname = LLVM.name(inst)
LLVM.name!(inst, "")
newinst = call!(B, called_type(inst), called_operand(inst), uservals, collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), prevname)
for idx = [LLVM.API.LLVMAttributeFunctionIndex, LLVM.API.LLVMAttributeReturnIndex, [LLVM.API.LLVMAttributeIndex(i) for i in 1:(length(arguments(inst)))]...]
idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx)
count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx);
Attrs = Base.unsafe_convert(Ptr{LLVM.API.LLVMAttributeRef}, Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef)*count))
LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs)
for j in 1:count
LLVM.API.LLVMAddCallSiteAttribute(newinst, idx, unsafe_load(Attrs, j))
end
Libc.free(Attrs)
end
API.EnzymeCopyMetadata(newinst, inst)
callconv!(newinst, callconv(inst))
push!(replaceAndErase, (inst, newinst))
end
continue
end
newbundles = OperandBundleDef[]
for bunduse in operand_bundles(inst)
bunduse = LLVM.OperandBundleDef(bunduse)
if LLVM.tag_name(bunduse) != "jl_roots"
Expand All @@ -181,7 +223,7 @@ function rewrite_ccalls!(mod::LLVM.Module)
subchanged = false
for lval in LLVM.inputs(bunduse)
llty = value_type(lval)
if !isa(llty, LLVM.PointerType) || LLVM.addrspace(llty) != 10
if isa(llty, LLVM.PointerType)
push!(uservals, lval)
continue
end
Expand Down

0 comments on commit 21b0762

Please sign in to comment.