diff --git a/src/Enzyme.jl b/src/Enzyme.jl index c33b3a16d1..fb2eea7466 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1333,9 +1333,10 @@ end end """ - jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}) + jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}=Val(1)) + jacobian(::ReverseMode, f, x) -Compute the jacobian of an array-input function `f` using (potentially vector) +Compute the jacobian of an array-output function `f` using (potentially vector) reverse mode. The `chunk` argument denotes the chunk size to use and `num_outs` denotes the number of outputs `f` will return in an array. @@ -1493,6 +1494,23 @@ end end end +@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten} + res = f(x) + jac = if res isa AbstractArray + jacobian(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x, Val(length(jac))) + elseif res isa AbstractFloat + gradient(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x) + else + throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))")) + end + + if ReturnPrimal + (res, jac) + else + jac + end +end + """ hvp(f::F, x::X, v::X) where {F, X}