From 7669a1e5cd6e8fcb467a0f19acff7b16c935522b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 12 Jun 2024 19:42:17 -0700 Subject: [PATCH] starting batching --- src/compiler.jl | 119 +++++++++++++++++++++++++++++++++++++++++----- test/abi.jl | 1 + test/usermixed.jl | 91 +++++++++++++++++++++++++++++++++++ 3 files changed, 198 insertions(+), 13 deletions(-) create mode 100644 test/usermixed.jl diff --git a/src/compiler.jl b/src/compiler.jl index daef61e583..658a5b925d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3429,7 +3429,7 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr else push!(args_activity, API.DFT_OUT_DIFF) end - elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc + elseif T <: Duplicated || T<: BatchDuplicated || T<: BatchDuplicatedFunc || T <: MixedDuplicated || T <: BatchMixedDuplicated push!(args_activity, API.DFT_DUP_ARG) elseif T <: DuplicatedNoNeed || T<: BatchDuplicatedNoNeed push!(args_activity, API.DFT_DUP_NONEED) @@ -3613,7 +3613,6 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, isboxed = GPUCompiler.deserves_argbox(source_typ) llvmT = isboxed ? 
T_prjlvalue : convert(LLVMType, source_typ) - push!(T_wrapperargs, llvmT) if T <: Const || T <: BatchDuplicatedFunc @@ -3642,6 +3641,11 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if is_adjoint && i != 1 push!(ActiveRetTypes, Nothing) end + elseif T <: MixedDuplicated || T <: BatchMixedDuplicated + push!(T_wrapperargs, LLVM.LLVMType(API.EnzymeGetShadowType(width, T_prjlvalue))) + if is_adjoint && i != 1 + push!(ActiveRetTypes, Nothing) + end else error("calling convention should be annotated, got $T") end @@ -3824,7 +3828,23 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, if isghostty(T′) || Core.Compiler.isconstType(T′) continue end - push!(realparms, params[i]) + + isboxed = GPUCompiler.deserves_argbox(T′) + + llty = value_type(params[i]) + + convty = convert(LLVMType, T′; allow_boxed=true) + + if (T <: MixedDuplicated || T <: BatchMixedDuplicated) && !isboxed # && (isa(llty, LLVM.ArrayType) || isa(llty, LLVM.StructType)) + al = emit_allocobj!(builder, Base.RefValue{T′}) + al = bitcast!(builder, al, LLVM.PointerType(llty, addrspace(value_type(al)))) + store!(builder, params[i], al) + al = addrspacecast!(builder, al, LLVM.PointerType(llty, Derived)) + push!(realparms, al) + else + push!(realparms, params[i]) + end + i += 1 if T <: Const elseif T <: Active @@ -3852,6 +3872,25 @@ function create_abi_wrapper(enzymefn::LLVM.Function, TT, rettype, actualRetType, elseif T <: Duplicated || T <: DuplicatedNoNeed push!(realparms, params[i]) i += 1 + elseif T <: MixedDuplicated || T <: BatchMixedDuplicated + pv = params[i] + isboxed = GPUCompiler.deserves_argbox(T′) + + resty = isboxed ? llty : LLVM.PointerType(llty, Derived) + + ival = UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, LLVM.PointerType(llty, Derived)))) + for idx in 1:width + pv = (width == 1) ? 
params[i] : extract_value!(builder, params[i], idx-1) + pv = bitcast!(builder, pv, LLVM.PointerType(llty, addrspace(value_type(pv)))) + pv = addrspacecast!(builder, pv, LLVM.PointerType(llty, Derived)) + if isboxed + pv = load!(builder, llty, pv, "mixedboxload") + end + ival = (width == 1 ) ? pv : insert_value!(builder, ival, pv, idx-1) + end + + push!(realparms, ival) + i += 1 elseif T <: BatchDuplicated || T <: BatchDuplicatedNoNeed isboxed = GPUCompiler.deserves_argbox(NTuple{width, T′}) val = params[i] @@ -4382,6 +4421,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function # generate the wrapper function type & definition wrapper_types = LLVM.LLVMType[] + wrapper_attrs = Vector{LLVM.Attribute}[] _, sret, returnRoots = get_return_info(actualRetType) sret_union = is_sret_union(actualRetType) @@ -4416,31 +4456,44 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if swiftself push!(wrapper_types, value_type(parameters(entry_f)[1+sret+returnRoots])) + push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("swiftself")]) end boxedArgs = Set{Int}() loweredArgs = Set{Int}() + raisedArgs = Set{Int}() for arg in args typ = arg.codegen.typ if GPUCompiler.deserves_argbox(arg.typ) push!(boxedArgs, arg.arg_i) push!(wrapper_types, typ) + push!(wrapper_attrs, LLVM.Attribute[]) elseif arg.cc != GPUCompiler.BITS_REF - push!(wrapper_types, typ) + if TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: BatchMixedDuplicated + push!(boxedArgs, arg.arg_i) + push!(raisedArgs, arg.arg_i) + push!(wrapper_types, LLVM.PointerType(typ, Derived)) + push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")]) + else + push!(wrapper_types, typ) + push!(wrapper_attrs, LLVM.Attribute[]) + end else # bits ref, and not boxed - # if TT.parameters[arg.arg_i] <: Const - # push!(boxedArgs, arg.arg_i) - # push!(wrapper_types, typ) - # else + if TT.parameters[arg.arg_i] <: MixedDuplicated || TT.parameters[arg.arg_i] <: 
BatchMixedDuplicated + push!(boxedArgs, arg.arg_i) + push!(wrapper_types, typ) + push!(wrapper_attrs, LLVM.Attribute[EnumAttribute("noalias")]) + else push!(wrapper_types, eltype(typ)) + push!(wrapper_attrs, LLVM.Attribute[]) push!(loweredArgs, arg.arg_i) - # end + end end end - if length(loweredArgs) == 0 && !sret && !sret_union + if length(loweredArgs) == 0 && length(raisedArgs) == 0 && !sret && !sret_union return entry_f, returnRoots, boxedArgs, loweredArgs end @@ -4461,8 +4514,10 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function end push!(function_attributes(wrapper_f), EnumAttribute("returns_twice")) push!(function_attributes(entry_f), EnumAttribute("returns_twice")) - if swiftself - push!(parameter_attributes(wrapper_f, 1), EnumAttribute("swiftself")) + for (i, v) in enumerate(wrapper_attrs) + for attr in v + push!(parameter_attributes(wrapper_f, i), attr) + end end seen = TypeTreeTable() @@ -4488,6 +4543,12 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function parm = ops[arg.codegen.i] if arg.arg_i in loweredArgs push!(nops, load!(builder, convert(LLVMType, arg.typ), parm)) + elseif arg.arg_i in raisedArgs + obj = emit_allocobj!(builder, arg.typ) + bc = bitcast!(builder, obj, LLVM.PointerType(value_type(parm), addrspace(value_type(obj)))) + store!(builder, parm, bc) + addr = addrspacecast!(builder, bc, LLVM.PointerType(value_type(parm), Derived)) + push!(nops, addr) else push!(nops, parm) end @@ -4572,6 +4633,13 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(arg.typ, ctx, dl, seen)))) push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), 
StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) + elseif arg.arg_i in raisedArgs + wrapparm = load!(builder, convert(LLVMType, arg.typ), wrapparm) + ctx = LLVM.context(wrapparm) + push!(wrapper_args, wrapparm) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzyme_type", string(typetree(Base.RefValue{arg.typ}, ctx, dl, seen)))) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) + push!(parameter_attributes(wrapper_f, arg.codegen.i-sret-returnRoots), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) else push!(wrapper_args, wrapparm) for attr in collect(parameter_attributes(entry_f, arg.codegen.i)) @@ -4651,6 +4719,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function elseif LLVM.return_type(entry_ft) == LLVM.VoidType() ret!(builder) else + ctx = LLVM.context(wrapper_f) push!(return_attributes(wrapper_f), StringAttribute("enzyme_type", string(typetree(actualRetType, ctx, dl, seen)))) push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(actualRetType))))) push!(return_attributes(wrapper_f), StringAttribute("enzymejl_parmtype_ref", string(UInt(GPUCompiler.BITS_REF)))) @@ -4712,7 +4781,7 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMReturnStatusAction) != 0 msg = sprint() do io println(io, string(mod)) - println(io, LVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) + println(io, LLVM.API.LLVMVerifyFunction(wrapper_f, LLVM.API.LLVMPrintMessageAction)) println(io, string(wrapper_f)) println(io, "parmsRemoved=", parmsRemoved, " retRemoved=", retRemoved, " prargs=", prargs) println(io, "Broken function") @@ -5991,6 +6060,30 @@ end push!(ActiveRetTypes, Nothing) 
end push!(ccexprs, argexpr) + elseif T <: MixedDuplicated + if RawCall + argexpr = argexprs[i] + i+=1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + push!(types, Any) + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) + elseif T <: BatchMixedDuplicated + if RawCall + argexpr = argexprs[i] + i+=1 + else + argexpr = Expr(:., expr, QuoteNode(:dval)) + end + push!(types, NTuple{width, Base.RefValue}) + if is_adjoint + push!(ActiveRetTypes, Nothing) + end + push!(ccexprs, argexpr) else error("calling convention should be annotated, got $T") end diff --git a/test/abi.jl b/test/abi.jl index 7371af504e..8d4251bb70 100644 --- a/test/abi.jl +++ b/test/abi.jl @@ -442,3 +442,4 @@ abssum(x) = sum(abs2, x); end +include("usermixed.jl") \ No newline at end of file
diff --git a/test/usermixed.jl b/test/usermixed.jl
new file mode 100644
index 0000000000..b8f7741f91
--- /dev/null
+++ b/test/usermixed.jl
@@ -0,0 +1,91 @@
+using Enzyme
+using Test
+
+function user_mixfnc(tup)
+    return tup[1] * tup[2][1]
+end
+
+@testset "MixedDuplicated struct call" begin
+    tup = (2.7, [3.14])
+    dtup = Ref((0.0, [0.0]))
+
+    res = autodiff(Reverse, user_mixfnc, Active, MixedDuplicated(tup, dtup))
+    @test dtup[][1] ≈ 3.14
+    @test dtup[][2] ≈ [2.7]
+end
+
+# Same product as `user_mixfnc`, but stored into `out` so each batch lane can be seeded through `dout`.
+function user_mixfnc_byref(out, tup)
+    out[] = tup[1] * tup[2][1]
+    return nothing
+end
+
+@testset "Batch MixedDuplicated struct call" begin
+    tup = (2.7, [3.14])
+    dtup = (Ref((0.0, [0.0])), Ref((0.0, [0.0])))
+    out = Ref(0.0)
+    dout = (Ref(1.0), Ref(3.0))
+    res = autodiff(Reverse, user_mixfnc_byref, Const, BatchDuplicated(out, dout), BatchMixedDuplicated(tup, dtup))
+    @test dtup[1][][1] ≈ 3.14
+    @test dtup[1][][2] ≈ [2.7]
+    @test dtup[2][][1] ≈ 3*3.14  # lane 2 is seeded with dout[2] = 3.0
+    @test dtup[2][][2] ≈ [3*2.7]
+end
+
+function mix_square(x)
+    return x * x
+end
+
+@testset "MixedDuplicated float64 call" begin
+    tup = 2.7
+    dtup = Ref(0.0)
+    res = autodiff(Reverse, mix_square, Active, MixedDuplicated(tup, dtup))[1]
+    @test res == (nothing,)  # `res` is already the per-argument tuple (indexed with [1] above)
+    @test dtup[] ≈ 2 * 2.7
+end
+
+# Same square as `mix_square`, but stored into `out` so each batch lane can be seeded through `dout`.
+function mix_square_byref(out, x)
+    out[] = x * x
+    return nothing
+end
+
+@testset "BatchMixedDuplicated float64 call" begin
+    tup = 2.7
+    dtup = (Ref(0.0), Ref(0.0))
+    out = Ref(0.0)
+    dout = (Ref(1.0), Ref(3.0))
+    res = autodiff(Reverse, mix_square_byref, Const, BatchDuplicated(out, dout), BatchMixedDuplicated(tup, dtup))[1]
+    @test res == (nothing, nothing)  # NOTE(review): expects one `nothing` per non-Active argument - confirm
+    @test dtup[1][] ≈ 2 * 2.7
+    @test dtup[2][] ≈ 3 * 2 * 2.7  # lane 2 is seeded with dout[2] = 3.0
+end
+
+function mix_ar(x)
+    return x[1] * x[2]
+end
+
+@testset "MixedDuplicated vector{float64} call" begin
+    tup = [2.7, 3.14]
+    dtup = Ref([0.0, 0.0])
+    res = autodiff(Reverse, mix_ar, Active, MixedDuplicated(tup, dtup))
+    @test res[1] == (nothing,)
+    @test dtup[] ≈ [3.14, 2.7]
+end
+
+# Same product as `mix_ar`, but stored into `out` so each batch lane can be seeded through `dout`.
+function mix_ar_byref(out, x)
+    out[] = x[1] * x[2]
+    return nothing
+end
+
+@testset "BatchMixedDuplicated vector{float64} call" begin
+    tup = [2.7, 3.14]
+    dtup = (Ref([0.0, 0.0]), Ref([0.0, 0.0]))
+    out = Ref(0.0)
+    dout = (Ref(1.0), Ref(3.0))
+    res = autodiff(Reverse, mix_ar_byref, Const, BatchDuplicated(out, dout), BatchMixedDuplicated(tup, dtup))
+    @test res[1] == (nothing, nothing)  # NOTE(review): expects one `nothing` per non-Active argument - confirm
+    @test dtup[1][] ≈ [3.14, 2.7]
+    @test dtup[2][] ≈ [3*3.14, 3*2.7]
+end