Skip to content

Commit

Permalink
starting batching
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 13, 2024
1 parent fdb235f commit 7669a1e
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 13 deletions.
119 changes: 106 additions & 13 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3429,7 +3429,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
else
push!(args_activity, API.DFT_OUT_DIFF)
end
elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc
elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated
push!(args_activity, API.DFT_DUP_ARG)
elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed
push!(args_activity, API.DFT_DUP_NONEED)
Expand Down Expand Up @@ -3613,7 +3613,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,

isboxed = GPUCompiler.deserves_argbox(source_typ)
llvmT = isboxed ? T_prjlvalue : convert(LLVMType, source_typ)

push!(T_wrapperargs, llvmT)

if T <: Const || T <: BatchDuplicatedFunc
Expand Down Expand Up @@ -3642,6 +3641,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
end
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue)))
if is_adjoint && i != 1
push!(ActiveRetTypes, Nothing)
end
else
error("calling convention should be annotated, got $T")
end
Expand Down Expand Up @@ -3824,7 +3828,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
if isghostty(T′) || Core.Compiler.isconstType(T′)
continue
end
push!(realparms, params[i])

isboxed = GPUCompiler.deserves_argbox(T′)

llty = value_type(params[i])

convty = convert(LLVMType, T′; allow_boxed=true)

if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType))
al = emit_allocobj!(builder, Base.RefValue{T′})
al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al))))
store!(builder, params[i], al)
al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived))
push!(realparms, al)
else
push!(realparms, params[i])
end

i += 1
if T <: Const
elseif T <: Active
Expand Down Expand Up @@ -3852,6 +3872,25 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType,
elseif T <: Duplicated || T <: DuplicatedNoNeed
push!(realparms, params[i])
i += 1
elseif T <: MixedDuplicated || T <: BatchMixedDuplicated
pv = params[i]
isboxed = GPUCompiler.deserves_argbox(T′)

resty = isboxed ? llty : LLVM.PointerType(llty, Derived)

ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, LLVM.PointerType(llty, Derived))))
for idx in 1:width
pv = (width == 1) ? params[i] : extract_value!(builder, params[i], idx-1)
pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv))))
pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived))
if isboxed
pv = load!(builder, llty, pv, "mixedboxload")
end
ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1)
end

push!(realparms, ival)
i += 1
elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed
isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′})
val = params[i]
Expand Down Expand Up @@ -4382,6 +4421,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function

# generate the wrapper function type & definition
wrapper_types = LLVM.LLVMType[]
wrapper_attrs = Vector{LLVM.Attribute}[]
_, sret, returnRoots = get_return_info(actualRetType)
sret_union = is_sret_union(actualRetType)

Expand Down Expand Up @@ -4416,31 +4456,44 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function

if swiftself
push!(wrapper_types, value_type(parameters(entry_f)[1+sret+returnRoots]))
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("swiftself")])
end

boxedArgs = Set{Int}()
loweredArgs = Set{Int}()
raisedArgs = Set{Int}()

for arg in args
typ = arg.codegen.typ
if GPUCompiler.deserves_argbox(arg.typ)
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[])
elseif arg.cc != GPUCompiler.BITS_REF
push!(wrapper_types, typ)
if TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated
push!(boxedArgs, arg.arg_i)
push!(raisedArgs, arg.arg_i)
push!(wrapper_types, LLVM.PointerType(typ, Derived))
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
else
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[])
end
else
# bits ref, and not boxed
# if TT.parameters[arg.arg_i] <: Const
# push!(boxedArgs, arg.arg_i)
# push!(wrapper_types, typ)
# else
if TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated
push!(boxedArgs, arg.arg_i)
push!(wrapper_types, typ)
push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")])
else
push!(wrapper_types, eltype(typ))
push!(wrapper_attrs, LLVM.Attribute[])
push!(loweredArgs, arg.arg_i)
# end
end
end
end

if length(loweredArgs) == 0 && !sret && !sret_union
if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union
return entry_f, returnRoots, boxedArgs, loweredArgs
end

Expand All @@ -4461,8 +4514,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
end
push!(function_attributes(wrapper_f), EnumAttribute("returns_twice"))
push!(function_attributes(entry_f), EnumAttribute("returns_twice"))
if swiftself
push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself"))
for (i, v) in enumerate(wrapper_attrs)
for attr in v
push!(parameter_attributes(wrapper_f, i), attr)
end
end

seen = TypeTreeTable()
Expand All @@ -4488,6 +4543,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
parm = ops[arg.codegen.i]
if arg.arg_i in loweredArgs
push!(nops, load!(builder, convert(LLVMType, arg.typ), parm))
elseif arg.arg_i in raisedArgs
obj = emit_allocobj!(builder, arg.typ)
bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj))))
store!(builder, parm, bc)
addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived))
push!(nops, addr)
else
push!(nops, parm)
end
Expand Down Expand Up @@ -4572,6 +4633,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
elseif arg.arg_i in raisedArgs
wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm)
ctx = LLVM.context(wrapparm)
push!(wrapper_args, wrapparm)
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
else
push!(wrapper_args, wrapparm)
for attr in collect(parameter_attributes(entry_f, arg.codegen.i))
Expand Down Expand Up @@ -4651,6 +4719,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
elseif LLVM.return_type(entry_ft) == LLVM.VoidType()
ret!(builder)
else
ctx = LLVM.context(wrapper_f)
push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType)))))
push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF))))
Expand Down Expand Up @@ -4712,7 +4781,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0
msg = sprint() do io
println(io, string(mod))
println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
println(io, LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction))
println(io, string(wrapper_f))
println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs)
println(io, "Broken function")
Expand Down Expand Up @@ -5991,6 +6060,30 @@ end
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
elseif T <: MixedDuplicated
if RawCall
argexpr = argexprs[i]
i+=1
else
argexpr = Expr(:., expr, QuoteNode(:dval))
end
push!(types, Any)
if is_adjoint
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
elseif T <: BatchMixedDuplicated
if RawCall
argexpr = argexprs[i]
i+=1
else
argexpr = Expr(:., expr, QuoteNode(:dval))
end
push!(types, NTuple{width, Base.RefValue})
if is_adjoint
push!(ActiveRetTypes, Nothing)
end
push!(ccexprs, argexpr)
else
error("calling convention should be annotated, got $T")
end
Expand Down
1 change: 1 addition & 0 deletions test/abi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,4 @@ abssum(x) = sum(abs2, x);

end

include("usermixed.jl")
91 changes: 91 additions & 0 deletions test/usermixed.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using Enzyme
using Test

function user_mixfnc(tup)
return tup[1] * tup[2][1]
end

@testset "MixedDuplicated struct call" begin
tup = (2.7, [3.14])
dtup = Ref((0.0, [0.0]))

res = autodiff(Reverse, user_mixfnc, Active, MixedDuplicated(tup, dtup))
@test dtup[][1] 3.14
@test dtup[][2] [2.7]
end


function user_mixfnc_byref(out, tup)
out[] = tup[1] * tup[2][1]
return nothing
end

@testset "Batch MixedDuplicated struct call" begin
tup = (2.7, [3.14])
dtup = (Ref((0.0, [0.0])), Ref((0.0, [0.0])))
out = Ref(0.0)
dout = (Ref(1.0), Ref(3.0))
res = autodiff(Reverse, user_mixfnc_byref, Const, BatchDuplicated(out, dout), BatchMixedDuplicated(tup, dtup))
@test dtup[1][][1] 3.14
@test dtup[1][][2] [2.7]
@test dtup[1][][1] 3*3.14
@test dtup[1][][2] [3*2.7]
end

function mix_square(x)
return x * x
end

@testset "MixedDuplicated float64 call" begin
tup = 2.7
dtup = Ref(0.0)
res = autodiff(Reverse, mix_square, Active, MixedDuplicated(tup, dtup))[1]
@test res[1] == (nothing,)
@test dtup[] 2 * 2.7
end


function mix_square_byref(out, x)
out[] = x * x
return nothing
end

@testset "BatchMixedDuplicated float64 call" begin
tup = 2.7
dtup = (Ref(0.0), Ref(0.0))
out = Ref(0.0)
dout = (Ref(1.0), Ref(3.0))
res = autodiff(Reverse, mix_square, Const, BatchDuplicated(out, dout), BatchMixedDuplicated(tup, dtup))[1]
@test res[1] == (nothing,)
@test dtup[1][] 2 * 2.7
@test dtup[1][] 3 * 2 * 2.7
end

function mix_ar(x)
return x[1] * x[2]
end

@testset "MixedDuplicated vector{float64} call" begin
tup = [2.7, 3.14]
dtup = Ref([0.0, 0.0])
res = autodiff(Reverse, mix_ar, Active, MixedDuplicated(tup, dtup))
@test res[1] == (nothing,)
@test dtup[] [3.14, 2.7]
end


function mix_ar_byref(out, x)
out[] = x[1] * x[2]
return nothing
end

@testset "BatchMixedDuplicated vector{float64} call" begin
tup = [2.7, 3.14]
dtup = (Ref([0.0, 0.0]), Ref([0.0, 0.0]))
out = Ref(0.0)
dout = (Ref(1.0), Ref(3.0))
res = autodiff(Reverse, mix_ar, Const, BatchDuplicated(out, dout), BatchMixedDuplicated(tup, dtup))
@test res[1] == (nothing,)
@test dtup[1][] [3.14, 2.7]
@test dtup[2][] [3*3.14, 3*2.7]
end

0 comments on commit 7669a1e

Please sign in to comment.