Skip to content

Commit

Permalink
ease jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Aug 8, 2024
1 parent 14a292a commit bbed112
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1333,9 +1333,10 @@ end
end

"""
jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk})
jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk}=Val(1))
jacobian(::ReverseMode, f, x)
Compute the jacobian of an array-input function `f` using (potentially vector)
Compute the jacobian of an array-output function `f` using (potentially vector)
reverse mode. The `chunk` argument denotes the chunk size to use and `num_outs`
denotes the number of outputs `f` will return in an array.
Expand Down Expand Up @@ -1493,6 +1494,23 @@ end
end
end

@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X) where {ReturnPrimal, F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten}
res = f(x)
jac = if res isa AbstractArray
jacobian(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x, Val(length(jac)))
elseif res isa AbstractFloat
gradient(ReverseMode{false,RABI, ErrIfFuncWritten}(), f, x)
else
throw(AssertionError("Unsupported return type of function for reverse-mode jacobian, $(Core.Typeof(res))"))
end

if ReturnPrimal
(res, jac)
else
jac
end
end

"""
hvp(f::F, x::X, v::X) where {F, X}
Expand Down

0 comments on commit bbed112

Please sign in to comment.