From 8d892e7eaacbc5e7cd98f13bd3b796d57cf029df Mon Sep 17 00:00:00 2001
From: William Moses <gh@wsmoses.com>
Date: Tue, 24 Dec 2024 12:29:15 -0500
Subject: [PATCH] Shadowalloc (#2218)

* Zero shadow alloc gc before primal

* fix

* Fix

* only npointers != 0

* pointerfree

* fix

* Update compiler.jl

* Update compiler.jl

* Update compiler.jl

* fix

* datatype

* Update compiler.jl

* Update compiler.jl
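
The scenario this addresses: in forward mode the shadow allocation is emitted
immediately after the primal allocation, so both objects exist before any store
initializes their fields. If the shadow allocation triggers a collection, the GC
can observe undefined pointer fields in the primal object. A minimal Julia sketch
of the kind of code affected (the Node and make names are illustrative only, not
taken from this patch):

    using Enzyme

    mutable struct Node
        value::Float64
        next::Any   # GC-tracked field; undefined at the LLVM level until stored
    end

    make(x) = Node(x, nothing).value

    # Forward mode duplicates the allocation: the primal Node and its shadow are
    # both created before either has its fields written, which is the window the
    # new zero-initialization covers.
    Enzyme.autodiff(Forward, make, Duplicated(1.0, 1.0))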
---
 src/compiler.jl          | 113 ++++++++++++++++++++++++++++++++++++---
 src/compiler/optimize.jl |   1 +
 src/llvm/transforms.jl   |  18 +++++++
 3 files changed, 124 insertions(+), 8 deletions(-)

diff --git a/src/compiler.jl b/src/compiler.jl
index e759ea0e89..536fce6096 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -594,19 +594,85 @@ function julia_undef_value_for_type(
     throw(AssertionError("Unknown type to val: $(Ty)"))
 end
 
-function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef)
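+# Store a valid GC value (`nothing`) into every GC-tracked pointer slot of the freshly
+# allocated object `prev` of type `Ty`, recursing into inline-stored fields, so a later
+# collection never observes undefined memory. Pointer-free types are skipped entirely.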
+function create_recursive_stores(B::LLVM.IRBuilder, @nospecialize(Ty::DataType), @nospecialize(prev::LLVM.Value))::Nothing
+    if Base.datatype_pointerfree(Ty)
+        return
+    end
+
+    isboxed_ref = Ref{Bool}()
+    LLVMType = LLVM.LLVMType(ccall(:jl_type_to_llvm, LLVM.API.LLVMTypeRef,
+                (Any, LLVM.Context, Ptr{Bool}), Ty, LLVM.context(), isboxed_ref))
+
+    if !isboxed_ref[]
+        zeroAll = false
+        T_int64 = LLVM.Int64Type()
+        prev = bitcast!(B, prev, LLVM.PointerType(LLVMType, addrspace(value_type(prev))))
+        prev = addrspacecast!(B, prev, LLVM.PointerType(LLVMType, Derived))
+        zero_single_allocation(B, Ty, LLVMType, prev, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true)
+    else
+        @assert fieldcount(Ty) != 0
+
+        T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
+        T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
+
+        T_int8 = LLVM.Int8Type()
+        T_int64 = LLVM.Int64Type()
+        
+        T_pint8 = LLVM.PointerType(T_int8)
+
+        prev2 = bitcast!(B, prev, LLVM.PointerType(T_int8, addrspace(value_type(prev))))
+
+        for i in 1:fieldcount(Ty)
+            Ty2 = fieldtype(Ty, i)
+            off = fieldoffset(Ty, i)
+    
+            if Ty2 isa DataType && Base.datatype_pointerfree(Ty2)
+                continue
+            end
+
+            prev3 = inbounds_gep!(
+                B,
+                T_int8,
+                prev2,
+                LLVM.Value[LLVM.ConstantInt(Int64(off))],
+            )
+            
+            fallback = Base.isabstracttype(Ty2) || Ty2 isa Union
+
+            @static if VERSION < v"1.11-"
+                fallback |= Ty2 <: Array
+            else
+                fallback |= Ty2 <: GenericMemory
+            end
+
+            if fallback
+                Ty2 = Any
+                zeroAll = false
+                prev3 = bitcast!(B, prev3, LLVM.PointerType(T_prjlvalue, addrspace(value_type(prev3))))
+                if addrspace(value_type(prev3)) != Derived
+                    prev3 = addrspacecast!(B, prev3, LLVM.PointerType(T_prjlvalue, Derived))
+                end
+                zero_single_allocation(B, Ty2, T_prjlvalue, prev3, zeroAll, LLVM.ConstantInt(T_int64, 0); atomic=true) 
+            else
+                create_recursive_stores(B, Ty2, prev3)
+            end
+        end
+    end
+end
+
+function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradientUtilsRef, Orig::LLVM.API.LLVMValueRef, idx::UInt64, prev::API.LLVMValueRef)
     V = LLVM.CallInst(V)
     gutils = GradientUtils(gutils)
     mode = get_mode(gutils)
+    has, Ty, byref = abs_typeof(V)
+    if !has
+        throw(AssertionError("$(string(LLVM.parent(LLVM.parent(V))))\n Allocation could not have its type statically determined $(string(V))"))
+    end
     if mode == API.DEM_ReverseModePrimal ||
        mode == API.DEM_ReverseModeGradient ||
        mode == API.DEM_ReverseModeCombined
         fn = LLVM.parent(LLVM.parent(V))
         world = enzyme_extract_world(fn)
-        has, Ty, byref = abs_typeof(V)
-        if !has
-            throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))"))
-        end
         rt = active_reg_inner(Ty, (), world)
         if rt == ActiveState || rt == MixedState
             B = LLVM.IRBuilder()
@@ -614,6 +680,26 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
             operands(V)[3] = unsafe_to_llvm(B, Base.RefValue{Ty})
         end
     end
+   
+    if mode == API.DEM_ForwardMode
+        # Zero out any jl_value_t inner elements of the preceding allocation.
+        # In forward mode, the original allocation runs first, followed by all
+        # shadow allocations, so every allocation executes before any value is
+        # stored into it. For example:
+        #   %orig = julia.gc_alloc(...)
+        #   %"orig'" = julia.gc_alloc(...)
+        #   store orig[0] = jlvaluet
+        #   store "orig'"[0] = jlvaluet'
+        # As a result, when a subsequent GC allocation runs, the memory of the
+        # preceding allocation may still be undefined and can trigger a GC error.
+        # To avoid this, we explicitly zero the GC'd fields of the previous allocation.
+        prev = LLVM.Instruction(prev)
+        B = LLVM.IRBuilder()
+        position!(B, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(prev)))
+
+        create_recursive_stores(B, Ty, prev)
+    end
+
     nothing
 end
 
@@ -671,7 +757,7 @@ function zero_allocation(B::LLVM.API.LLVMBuilderRef, LLVMType::LLVM.API.LLVMType
     return nothing
 end
 
-function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::DataType), @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(nobj::LLVM.Value), zeroAll::Bool, @nospecialize(idx::LLVM.Value))
+function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::DataType), @nospecialize(LLVMType::LLVM.LLVMType), @nospecialize(nobj::LLVM.Value), zeroAll::Bool, @nospecialize(idx::LLVM.Value); write_barrier=false, atomic=false)
     T_jlvalue = LLVM.StructType(LLVM.LLVMType[])
     T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked)
     T_prjlvalue_UT = LLVM.PointerType(T_jlvalue)
@@ -682,6 +768,7 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D
         jlType,
     )]
 
+    addedvals = LLVM.Value[]
     while length(todo) != 0
         path, ty, jlty = popfirst!(todo)
         if isa(ty, LLVM.PointerType)
@@ -689,12 +776,18 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D
                 loc = gep!(builder, LLVMType, nobj, path)
                 mod = LLVM.parent(LLVM.parent(Base.position(builder)))
                 fill_val = unsafe_nothing_to_llvm(mod)
+                push!(addedvals, fill_val)
                 loc = bitcast!(
                     builder,
                     loc,
                     LLVM.PointerType(T_prjlvalue, addrspace(value_type(loc))),
                 )
-                store!(builder, fill_val, loc)
+                st = store!(builder, fill_val, loc)
+                if atomic
+                    ordering!(st, LLVM.API.LLVMAtomicOrderingRelease)
+                    syncscope!(st, LLVM.SyncScope("singlethread"))
+                    metadata(st)["enzymejl_atomicgc"] = LLVM.MDNode(LLVM.Metadata[])
+                end
             elseif zeroAll
                 loc = gep!(builder, LLVMType, nobj, path)
                 store!(builder, LLVM.null(ty), loc)
@@ -741,6 +834,10 @@ function zero_single_allocation(builder::LLVM.IRBuilder, @nospecialize(jlType::D
             continue
         end
     end
+    if length(addedvals) != 0 && write_barrier
+        pushfirst!(addedvals, get_base_and_offset(nobj; offsetAllowed=false, inttoptr=false)[1])
+        emit_writebarrier!(builder, addedvals)
+    end
     return nothing
 
 end
@@ -1127,7 +1224,7 @@ function __init__()
         @cfunction(
             shadow_alloc_rewrite,
             Cvoid,
-            (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef)
+            (LLVM.API.LLVMValueRef, API.EnzymeGradientUtilsRef, LLVM.API.LLVMValueRef, UInt64, LLVM.API.LLVMValueRef)
         )
     )
     register_alloc_rules()
diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl
index f9769881a4..d3c4375abf 100644
--- a/src/compiler/optimize.jl
+++ b/src/compiler/optimize.jl
@@ -652,6 +652,7 @@ function addOptimizationPasses!(pm::LLVM.ModulePassManager, tm::LLVM.TargetMachi
     jl_inst_simplify!(pm)
     jump_threading!(pm)
     dead_store_elimination!(pm)
+    add!(pm, FunctionPass("SafeAtomicToRegularStore", safe_atomic_to_regular_store!))
 
     # More dead allocation (store) deletion before loop optimization
     # consider removing this:
diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl
index 7af6690138..3402d3232f 100644
--- a/src/llvm/transforms.jl
+++ b/src/llvm/transforms.jl
@@ -2446,3 +2446,21 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
     eraseInst(mod, func)
 end
 
+function safe_atomic_to_regular_store!(f::LLVM.Function)
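+    # Stores tagged with "enzymejl_atomicgc" were emitted as single-thread atomic
+    # release stores purely to keep freshly allocated GC memory valid through earlier
+    # optimization; strip the marker and demote them back to ordinary stores.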
+    changed = false
+    for bb in blocks(f), inst in instructions(bb)
+        if !isa(inst, LLVM.StoreInst)
+            continue
+        end
+        if !haskey(metadata(inst), "enzymejl_atomicgc")
+            continue
+        end
+        Base.delete!(metadata(inst), "enzymejl_atomicgc")
+        syncscope!(inst, LLVM.SyncScope("system"))
+        ordering!(inst, LLVM.API.LLVMAtomicOrderingNotAtomic)
+        changed = true
+    end
+    return changed
+end
+
+