Skip to content

Commit

Permalink
Jacobian returns one array (#351)
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy authored Jun 2, 2022
1 parent 52f0e86 commit d48e1f3
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 33 deletions.
22 changes: 15 additions & 7 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplic
export autodiff, jacobian, gradient, gradient!
export markType, batch_size, onehot, chunkedonehot

using LinearAlgebra

# Independent code, must be loaded before "compiler.jl"
include("pmap.jl")

Expand Down Expand Up @@ -703,20 +705,22 @@ grad = jacobian(Forward, f, [2.0, 3.0])
# output
([3.0, 0.0], [2.0, 1.0])
2×2 Matrix{Float64}:
3.0 2.0
0.0 1.0
```
"""
@inline function jacobian(::ForwardMode, args...; kwargs...)
gradient(Forward, args...; kwargs...)
cols = gradient(Forward, args...; kwargs...)
reduce(hcat, cols)
end

"""
jacobian(::ReverseMode, f, x, ::Val{num_outs}, ::Val{chunk})
Compute the jacobian of an array-input function `f` using (potentially vector)
reverse mode. The `chunk` argument denotes the chunk size to use and `num_outs`
denotes the number of outputs `f` will return in an array. Note that the result
of this is the transpose of the Forward [`jacobian`](@ref)
denotes the number of outputs `f` will return in an array.
Example:
Expand All @@ -727,7 +731,9 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))
# output
([3.0, 2.0], [0.0, 1.0])
2×2 Matrix{Float64}:
3.0 2.0
0.0 1.0
```
"""
@inline function jacobian(::ReverseMode, f, x, n_outs::Val{n_out_val}, ::Val{chunk}) where {chunk, n_out_val}
Expand Down Expand Up @@ -763,15 +769,16 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))
(i == num ? adjoint2 : adjoint)(BatchDuplicated(x, dx), tape)
return dx
end
tupleconcat(tmp...)
rows = tupleconcat(tmp...)
mapreduce(LinearAlgebra.adjoint, vcat, rows)
end

@inline function jacobian(::ReverseMode, f, x, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {n_out_val}
tt′ = Tuple{Duplicated{Core.Typeof(x)}}
tt = Tuple{Core.Typeof(x)}
rt = Core.Compiler.return_type(f, tt)
primal, adjoint = Enzyme.Compiler.thunk(f, #=df=#nothing, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), #=ModifiedBetween=#Val(false))
ntuple(n_outs) do i
rows = ntuple(n_outs) do i
Base.@_inline_meta
dx = zero(x)
res = primal(Duplicated(x, dx))
Expand All @@ -780,6 +787,7 @@ end
adjoint(Duplicated(x, dx), tape)
return dx
end
mapreduce(LinearAlgebra.adjoint, vcat, rows)
end


Expand Down
81 changes: 55 additions & 26 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -934,29 +934,32 @@ end
[v[2], v[1]*v[1], v[1]*v[1]*v[1]]
end

jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(1))
@test length(jac) == 3
@test jac[1] [ 0.0, 1.0]
@test jac[2] [ 4.0, 0.0]
@test jac[3] [12.0, 0.0]

jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], Val(1))
@test length(jac) == 2
@test jac[1] [ 0.0, 4.0, 12.0]
@test jac[2] [ 1.0, 0.0, 0.0]
jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(1))
@test size(jac) == (3, 2)
@test jac [ 0.0 1.0;
4.0 0.0;
12.0 0.0]

jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], Val(1))
@test size(jac) == (3, 2)
@test jac [ 0.0 1.0;
4.0 0.0;
12.0 0.0]

@test jac == Enzyme.jacobian(Forward, inout, [2.0, 3.0])
@test jac == ForwardDiff.jacobian(inout, [2.0, 3.0])

jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(2))
@test length(jac) == 3
@test jac[1] [ 0.0, 1.0]
@test jac[2] [ 4.0, 0.0]
@test jac[3] [12.0, 0.0]

jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], Val(2))
@test length(jac) == 2
@test jac[1] [ 0.0, 4.0, 12.0]
@test jac[2] [ 1.0, 0.0, 0.0]
jac = Enzyme.jacobian(Reverse, inout, [2.0, 3.0], #=n_outs=# Val(3), Val(2))
@test size(jac) == (3, 2)
@test jac [ 0.0 1.0;
4.0 0.0;
12.0 0.0]

jac = Enzyme.jacobian(Forward, inout, [2.0, 3.0], Val(2))
@test size(jac) == (3, 2)
@test jac [ 0.0 1.0;
4.0 0.0;
12.0 0.0]

function f_test_1(A, x)
u = A*x[2:end] .+ x[1]
Expand Down Expand Up @@ -984,12 +987,38 @@ end
x = ones(6)
A = Matrix{Float64}(LinearAlgebra.I, 5, 5)
u = Vector{Float64}(undef, 5)
@test J_r_1(A, x) == ([1.0, 1.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 1.0])
@test_broken J_r_2(A, x) == ([1.0, 1.0, 0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0, 0.0, 1.0])

# TODO fix forward vector bugs
# @test J_f_1(A, x) == (([1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]),)
# @test J_f_2(A, x) == (([1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 1.0]),)

@test J_r_1(A, x) == [
1.0 1.0 0.0 0.0 0.0 0.0;
1.0 0.0 1.0 0.0 0.0 0.0;
1.0 0.0 0.0 1.0 0.0 0.0;
1.0 0.0 0.0 0.0 1.0 0.0;
1.0 0.0 0.0 0.0 0.0 1.0;
]

@test_broken J_r_2(A, x) == [
1.0 1.0 0.0 0.0 0.0 0.0;
1.0 0.0 1.0 0.0 0.0 0.0;
1.0 0.0 0.0 1.0 0.0 0.0;
1.0 0.0 0.0 0.0 1.0 0.0;
1.0 0.0 0.0 0.0 0.0 1.0;
]

# Function fails verification in test/CI
# @test J_f_1(A, x) == [
# 1.0 1.0 0.0 0.0 0.0 0.0;
# 1.0 0.0 1.0 0.0 0.0 0.0;
# 1.0 0.0 0.0 1.0 0.0 0.0;
# 1.0 0.0 0.0 0.0 1.0 0.0;
# 1.0 0.0 0.0 0.0 0.0 1.0;
# ]
# @test J_f_2(A, x) == [
# 1.0 1.0 0.0 0.0 0.0 0.0;
# 1.0 0.0 1.0 0.0 0.0 0.0;
# 1.0 0.0 0.0 1.0 0.0 0.0;
# 1.0 0.0 0.0 0.0 1.0 0.0;
# 1.0 0.0 0.0 0.0 0.0 1.0;
# ]

# Bug on (augmented) forward pass deducing if
# shadow value is used
Expand Down

2 comments on commit d48e1f3

@vchuravy
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/61547

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.0 -m "<description of version>" d48e1f3287a7f062f173a062b246477fe6a2fe3b
git push origin v0.10.0

Please sign in to comment.