From 6c2b0d9926ea50f1e4a7215b80278e6d3b75a0bf Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 10 Jun 2024 16:12:08 -0700 Subject: [PATCH] Improve rule arg mixed errors (#1530) * Improve rule arg mixed errors * fixup * improve errs --- src/rules/customrules.jl | 16 +++++++++++++++- src/rules/jitrules.jl | 12 ++++++++++-- src/rules/typeunstablerules.jl | 12 ++++++------ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 989a733c01..c658850c2e 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -122,11 +122,14 @@ function enzyme_custom_setup_args(B, orig::LLVM.CallInst, gutils::GradientUtils, push!(activity, Ty) - elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg(arg.typ, world) ) + elseif activep == API.DFT_OUT_DIFF || (mode != API.DEM_ForwardMode && active_reg_inner(arg.typ, (), world, #=justActive=#Val(true)) == ActiveState) Ty = Active{arg.typ} llty = convert(LLVMType, Ty) arty = convert(LLVMType, arg.typ; allow_boxed=true) if B !== nothing + if active_reg_inner(arg.typ, (), world, #=justActive=#Val(false)) == MixedState + emit_error(B, orig, "Enzyme: Argument type $(arg.typ) has mixed internal activity types in evaluation of custom rule for $mi. See https://enzyme.mit.edu/julia/stable/faq/#Mixed-activity for more information") + end al0 = al = emit_allocobj!(B, Ty) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) al = addrspacecast!(B, al, LLVM.PointerType(llty, Derived)) @@ -716,6 +719,17 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, tape_idx = 1+(kwtup!==nothing && !isghostty(kwtup))+(isKWCall && !isghostty(rev_TT.parameters[4])) innerTy = value_type(parameters(llvmf)[tape_idx+(sret !== nothing)+(RT <: Active)]) if innerTy != value_type(tape) + if isabstracttype(TapeT) + msg = sprint() do io + println(io, "Enzyme : mismatch between innerTy $innerTy and tape type $(value_type(tape))") + println(io, "tape_idx=", tape_idx) + println(io, "sret=", sret) + println(io, "RT=", RT) + println(io, "tape=", tape) + println(io, "llvmf=", string(llvmf)) + end + throw(AssertionError(msg)) + end llty = convert(LLVMType, TapeT; allow_boxed=true) al0 = al = emit_allocobj!(B, TapeT) al = bitcast!(B, al, LLVM.PointerType(llty, addrspace(value_type(al)))) diff --git a/src/rules/jitrules.jl b/src/rules/jitrules.jl index 4aaa1c813d..f2f9d27407 100644 --- a/src/rules/jitrules.jl +++ b/src/rules/jitrules.jl @@ -1111,7 +1111,7 @@ for (N, Width) in Iterators.product(0:30, 1:10) eval(func_runtime_iterate_rev(N, Width)) end -function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true) +function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, lookup; sret=nothing, tape=nothing, firstconst=false, endcast=true, firstconst_after_tape=true) width = get_width(gutils) mode = get_mode(gutils) mod = LLVM.parent(LLVM.parent(LLVM.parent(orig))) @@ -1132,7 +1132,7 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, T_jlvalue = LLVM.StructType(LLVM.LLVMType[]) T_prjlvalue = LLVM.PointerType(T_jlvalue, Tracked) - if firstconst + if firstconst && !firstconst_after_tape val = new_from_original(gutils, operands(orig)[start]) if lookup val = lookup_value(gutils, val, B) @@ -1196,6 +1196,14 @@ function generic_setup(orig, func, ReturnType, gutils, start, B::LLVM.IRBuilder, else pushfirst!(vals, unsafe_to_llvm(Val(ReturnType))) end + + if firstconst && firstconst_after_tape + val = new_from_original(gutils, operands(orig)[start]) + if lookup + val = lookup_value(gutils, val, B) + end + pushfirst!(vals, val) + end if mode != API.DEM_ForwardMode uncacheable = get_uncacheable(gutils, orig) diff --git a/src/rules/typeunstablerules.jl b/src/rules/typeunstablerules.jl index 101796401f..36f2798c0c 100644 --- a/src/rules/typeunstablerules.jl +++ b/src/rules/typeunstablerules.jl @@ -181,19 +181,19 @@ function body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primarg end function func_runtime_newstruct_augfwd(N, Width) - primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width) + primargs, _, primtypes, allargs, typeargs, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width; mixed_or_active=true) body = body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) quote - function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)} + function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{$Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, $(allargs...))::ReturnType where {ActivityTup, MB, ReturnType, NewType, $(typeargs...)} $body end end end -@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, RT::Val{ReturnType}, ::Type{NewType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType} +@generated function runtime_newstruct_augfwd(activity::Type{Val{ActivityTup}}, width::Val{Width}, ModifiedBetween::Val{MB}, ::Type{NewType}, RT::Val{ReturnType}, allargs...)::ReturnType where {ActivityTup, MB, Width, ReturnType, NewType} N = div(length(allargs)+2, Width+1)-1 - primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs) + primargs, _, primtypes, _, _, wrapped, batchshadowargs, _, active_refs = setup_macro_wraps(false, N, Width, :allargs; mixed_or_active=true) return body_runtime_newstruct_augfwd(N, Width, primtypes, active_refs, primargs, batchshadowargs) end @@ -325,7 +325,7 @@ function common_newstructv_augfwd(offset, B, orig, gutils, normalR, shadowR, tap width = get_width(gutils) - sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false) + sret = generic_setup(orig, runtime_newstruct_augfwd, width == 1 ? Any : AnyArray(Int(width)), gutils, #=start=#offset, B, false; firstconst=true, endcast = false, firstconst_after_tape=true) if width == 1 shadow = sret @@ -369,7 +369,7 @@ function common_newstructv_rev(offset, B, orig, gutils, tape) if !newstruct_common(#=fwd=#false, #=run=#false, offset, B, orig, gutils, #=normalR=#nothing, #=shadowR=#nothing) @assert tape !== C_NULL width = get_width(gutils) - generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape) + generic_setup(orig, runtime_newstruct_rev, Nothing, gutils, #=start=#offset, B, true; firstconst=true, tape, firstconst_after_tape=true) end return nothing