From 5609c7edf976717f2678dc5d300d1a8bccf8bb64 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 28 May 2024 11:48:53 +0200 Subject: [PATCH] Nice union{} error (#1479) * Nice union{} error * fixup --- src/Enzyme.jl | 4 +-- src/compiler.jl | 79 +++++++++++++++++++++++++++++++++--------------- test/runtests.jl | 9 ++++++ 3 files changed, 66 insertions(+), 26 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c75508cd77..911d1801ad 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -230,7 +230,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end if A <: Active - if !allocatedinline(rt) || rt isa Union + if (!allocatedinline(rt) || rt isa Union) && rt != Union{} forward, adjoint = Enzyme.Compiler.thunk(Val(world), FA, Duplicated{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), Val(width), ModifiedBetween, #=ReturnPrimal=#Val(ReturnPrimal), #=ShadowInit=#Val(true), RABI) res = forward(f, args...) tape = res[1] @@ -244,7 +244,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) throw(ErrorException("Duplicated Returns not yet handled")) end - if A <: Active && rt <: Complex + if (A <: Active && rt <: Complex) && rt != Union{} if Holomorphic seen = IdDict() seen2 = IdDict() diff --git a/src/compiler.jl b/src/compiler.jl index 8e897c2526..8cbac14f94 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -577,6 +577,10 @@ struct AdjointThunk{PT, FA, RT, TT, Width, TapeType} <: AbstractThunk{FA, RT, TT adjoint::PT end +struct PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World} <: AbstractThunk{FA, RT, TT, Width} + adjoint::PT +end + @inline return_type(::AbstractThunk{FA, RT}) where {FA, RT} = RT @inline return_type(::Type{AugmentedForwardThunk{PT, FA, RT, TT, Width, ReturnPrimal, TapeType}}) where {PT, FA, RT, TT, Width, ReturnPrimal, TapeType} = RT @@ -5277,7 +5281,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; cf = LLVM.called_operand(tmp) if isa(cf, LLVM.Function) nm = LLVM.name(cf) - if nm == "gpu_signal_exception" || nm == "gpu_report_exception" + if nm == "gpu_signal_exception" || nm == "gpu_report_exception" || nm == "ijl_throw" || nm == "jl_throw" shouldemit = false break end @@ -5433,6 +5437,9 @@ struct CompileResult{AT, PT} TapeType::Type end +@inline (thunk::PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World})(fn::FA, args...) where {PT, FA, RT, TT, Width, ReturnPrimal, World} = +enzyme_call(Val(false), thunk.adjoint, PrimalErrorThunk{PT, FA, RT, TT, Width, ReturnPrimal, World}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) + @inline (thunk::CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal})(fn::FA, args...) where {PT, FA, Width, RT, TT, ReturnPrimal} = enzyme_call(Val(false), thunk.adjoint, CombinedAdjointThunk{PT, FA, RT, TT, Width, ReturnPrimal}, Val(Width), Val(ReturnPrimal), TT, RT, fn, Cvoid, args...) @@ -5536,7 +5543,9 @@ end end @inline function default_adjoint(T) - if T <: AbstractFloat + if T == Union{} + return nothing + elseif T <: AbstractFloat return one(T) elseif T <: Complex error("Attempted to use automatic pullback (differential return value) deduction on a either a type unstable function returning an active complex number, or autodiff_deferred returning an active complex number. For the first case, please type stabilize your code, e.g. by specifying autodiff(Reverse, f->f(x)::Complex, ...). For the second case, please use regular non-deferred autodiff") @@ -5559,7 +5568,7 @@ end JuliaContext() do ctx F = eltype(FA) - is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk + is_forward = CC <: AugmentedForwardThunk || CC <: ForwardModeThunk || CC <: PrimalErrorThunk is_adjoint = CC <: AdjointThunk || CC <: CombinedAdjointThunk is_split = CC <: AdjointThunk || CC <: AugmentedForwardThunk needs_tape = CC <: AdjointThunk @@ -5569,23 +5578,33 @@ end argtypes = DataType[argtt.parameters...] argexprs = Union{Expr, Symbol}[:(args[$i]) for i in 1:N] - if !RawCall + if false && CC <: PrimalErrorThunk + primargs = [quote + convert($(eltype(T)), $(argexprs[i]).val) + end for (i, T) in enumerate(argtypes)] + return quote + fn.val($(primargs...)) + error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + end + end + + if !RawCall && !(CC <: PrimalErrorThunk) if rettype <: Active if length(argtypes) + is_adjoint + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end elseif rettype <: Const if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end else if length(argtypes) + needs_tape != length(argexprs) return quote - throw(MethodError($CC($fptr), $args)) + throw(MethodError($CC(fptr), $args)) end end end @@ -5593,8 +5612,10 @@ end types = DataType[] - if eltype(rettype) === Union{} - error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + if eltype(rettype) === Union{} && false + return quote + error("Function to differentiate is guaranteed to return an error and doesn't make sense to autodiff. Giving up") + end end if !(rettype <: Const) && (isghostty(eltype(rettype)) || Core.Compiler.isconstType(eltype(rettype)) || eltype(rettype) === DataType) rrt = eltype(rettype) @@ -5665,7 +5686,9 @@ end end continue end - + if CC <: PrimalErrorThunk + continue + end if T <: Active if is_adjoint if width == 1 @@ -5752,8 +5775,10 @@ end end push!(sret_types, NT) end - - @assert i == length(argexprs)+1 + + if !(CC <: PrimalErrorThunk) + @assert i == length(argexprs)+1 + end # Tape if CC <: AugmentedForwardThunk @@ -5785,7 +5810,7 @@ end T_void = convert(LLVMType, Nothing) - combinedReturn = Tuple{sret_types...} + combinedReturn = (CC <: PrimalErrorThunk && eltype(rettype) == Union{}) ? Union{} : Tuple{sret_types...} if any(any_jltypes(convert(LLVM.LLVMType, T; allow_boxed=true)) for T in sret_types) combinedReturn = AnonymousStruct(combinedReturn) end @@ -6003,29 +6028,30 @@ end params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - sig = Tuple{eltype(FA), map(eltype, TT.parameters)...} - interp = GPUCompiler.get_interpreter(tmp_job) # TODO check compile return here, early # rrt = Core.Compiler.return_type(f, primal.tt) # nothing rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) + rrt = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype + + run_enzyme = true if rrt == Union{} - estr = "Function to differentiate `$mi` is guaranteed to return an error and doesn't make sense to autodiff. Giving up" - return quote - error($estr) - end + run_enzyme = false + A = Const end - if !(A <: Const) && guaranteed_const_nongen(rrt, World) + if run_enzyme && !(A <: Const) && guaranteed_const_nongen(rrt, World) estr = "Return type `$rrt` not marked Const, but type is guaranteed to be constant" return quote error($estr) end end - rt2 = if A isa UnionAll + rt2 = if !run_enzyme + Const{rrt} + elseif A isa UnionAll A{rrt} else @assert A isa DataType @@ -6034,7 +6060,7 @@ end A end - params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) + params = Compiler.EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, rt2, run_enzyme, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit, UnknownTapeType, ABI) job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) # We need to use primal as the key, to lookup the right method @@ -6045,7 +6071,13 @@ end compile_result = cached_compilation(job) - if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient + if !run_enzyme + ErrT = PrimalErrorThunk{typeof(compile_result.adjoint), FA, rt2, TT, width, ReturnPrimal, World} + return quote + Base.@_inline_meta + $ErrT($(compile_result.adjoint)) + end + elseif Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient TapeType = compile_result.TapeType AugT = AugmentedForwardThunk{typeof(compile_result.primal), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, ReturnPrimal, TapeType} AdjT = AdjointThunk{typeof(compile_result.adjoint), FA, rt2, Tuple{params.TT.parameters[2:end]...}, width, TapeType} @@ -6086,7 +6118,6 @@ import GPUCompiler: deferred_codegen_jobs params = EnzymeCompilerParams(Tuple{FA, TT.parameters...}, Mode, width, remove_innerty(A), true, #=abiwrap=#true, ModifiedBetween, ReturnPrimal, ShadowInit,ExpectedTapeType, FFIABI) tmp_job = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel=false), World) - sig = Tuple{eltype(FA), map(eltype, TT.parameters)...} interp = GPUCompiler.get_interpreter(tmp_job) rrt = something(Core.Compiler.typeinf_type(interp, mi.def, mi.specTypes, mi.sparam_vals), Any) diff --git a/test/runtests.jl b/test/runtests.jl index 182d8291df..f7132e1d75 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2602,6 +2602,15 @@ end @test 2.0 ≈ Enzyme.autodiff(Reverse, unionret, Active, Active(2.0), Duplicated(out, dout), Const(true))[1][1] end + +function assured_err(x) + throw(AssertionError("foo")) +end + +@testset "UnionAll" begin + @test_throws AssertionError Enzyme.autodiff(Reverse, assured_err, Active, Active(2.0)) +end + struct MyFlux end