Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow choice of AD backend for conditions #80

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ SimpleUnPack = "1.1"
julia = "1.6"

[extras]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Expand All @@ -52,4 +53,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Documenter", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Pkg", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"]
test = ["AbstractDifferentiation", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Documenter", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Pkg", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"]
2 changes: 1 addition & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.2"
manifest_format = "2.0"
project_hash = "ff84ddc3d5227f964f2cd507ce5cbc83b4fba207"
project_hash = "01aa55c2eb0613f5724b3240c0f6da431aa9c124"

[[deps.AMD]]
deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"]
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Expand Down
24 changes: 11 additions & 13 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@

## Supported autodiff backends

| Mode | Backend | Support |
| ------- | ---------------------------------------------------------- | ------- |
| Forward | [ForwardDiff.jl] | yes |
| Reverse | [ChainRules.jl]-compatible ([Zygote.jl], [ReverseDiff.jl]) | yes |
| Forward | [ChainRules.jl]-compatible ([Diffractor.jl]) | soon |
| Both | [Enzyme.jl] | someday |

[ForwardDiff.jl]: https://github.com/JuliaDiff/ForwardDiff.jl
[ChainRules.jl]: https://github.com/JuliaDiff/ChainRules.jl
[Zygote.jl]: https://github.com/FluxML/Zygote.jl
[ReverseDiff.jl]: https://github.com/JuliaDiff/ReverseDiff.jl
[Enzyme.jl]: https://github.com/EnzymeAD/Enzyme.jl
[Diffractor.jl]: https://github.com/JuliaDiff/Diffractor.jl
| Mode | Backend | Support |
| ------- | ---------------------------------------------------------------------- | ------- |
| Forward | [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | yes |
| Reverse | [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | yes |
| Forward | [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | soon |
| Both | [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | someday |

Note that these are the backends supported to differentiate an `ImplicitFunction`.
You choose one of them when you call e.g. `ForwardDiff.jacobian(implicit, x)` or `Zygote.jacobian(implicit, x)`.

To differentiate the `Conditions`, you can select any backend compatible with [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl), and specify it with the `conditions_backend` argument to the `ImplicitFunction` constructor.

## Writing conditions

Expand Down
4 changes: 3 additions & 1 deletion examples/0_intro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ We can even go higher-order by mixing the two packages (forward-over-reverse mod
The only technical requirement is to switch the linear solver to something that can handle dual numbers:
=#

implicit_higher_order = ImplicitFunction(forward, conditions, DirectLinearSolver())
implicit_higher_order = ImplicitFunction(
forward, conditions; linear_solver=DirectLinearSolver()
)

#=
Then the Jacobian itself is differentiable.
Expand Down
16 changes: 9 additions & 7 deletions ext/ImplicitDifferentiationChainRulesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ using LinearOperators: LinearOperator
using SimpleUnPack: @unpack

"""
rrule(rc, implicit, x; kwargs...)
rrule(rc, implicit, x, ReturnByproduct(); kwargs...)
rrule(rc, implicit, x[, ReturnByproduct()]; kwargs...)

Custom reverse rule for an [`ImplicitFunction`](@ref), to ensure compatibility with reverse mode autodiff.

Expand All @@ -34,14 +33,17 @@ function ChainRulesCore.rrule(
y, z = implicit(x, ReturnByproduct(); kwargs...)
n, m = length(x), length(y)

backend = ReverseRuleConfigBackend(rc)
if implicit.conditions.backend !== nothing
backend = implicit.conditions.backend
else
backend = ReverseRuleConfigBackend(rc)
end

pbA = pullback_function(backend, _y -> conditions(x, _y, z; kwargs...), y)
pbB = pullback_function(backend, _x -> conditions(_x, y, z; kwargs...), x)
pbmA = PullbackMul!(pbA, size(y))
pbmB = PullbackMul!(pbB, size(y))

Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA)
Bᵀ_op = LinearOperator(R, n, m, false, false, pbmB)
Aᵀ_op = LinearOperator(R, m, m, false, false, PullbackMul!(pbA, size(y)))
Bᵀ_op = LinearOperator(R, n, m, false, false, PullbackMul!(pbB, size(y)))
Aᵀ_op_presolved = presolve(linear_solver, Aᵀ_op, y)

implicit_pullback = ImplicitPullback(Aᵀ_op_presolved, Bᵀ_op, linear_solver, x)
Expand Down
10 changes: 7 additions & 3 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ using LinearOperators: LinearOperator
using SimpleUnPack: @unpack

"""
implicit(x_and_dx::AbstractArray{<:Dual}; kwargs...)
implicit(x_and_dx::AbstractArray{<:Dual}, ReturnByproduct(); kwargs...)
implicit(x_and_dx::AbstractArray{<:Dual}[, ReturnByproduct()]; kwargs...)

Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility with forward mode autodiff.

Expand All @@ -37,7 +36,12 @@ function (implicit::ImplicitFunction)(
y, z = implicit(x, ReturnByproduct(); kwargs...)
n, m = length(x), length(y)

backend = ForwardDiffBackend()
if implicit.conditions.backend !== nothing
backend = implicit.conditions.backend
else
backend = ForwardDiffBackend()
end

pfA = pushforward_function(backend, _y -> conditions(x, _y, z; kwargs...), y)
pfB = pushforward_function(backend, _x -> conditions(_x, y, z; kwargs...), x)

Expand Down
1 change: 1 addition & 0 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module ImplicitDifferentiation

using AbstractDifferentiation: AbstractBackend
using Krylov: KrylovStats, gmres
using LinearOperators: LinearOperators, LinearOperator
using LinearAlgebra: lu, SingularException
Expand Down
18 changes: 12 additions & 6 deletions src/conditions.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
"""
Conditions{byproduct,C}
Conditions{byproduct,C,B}

Callable wrapper for conditions `c::C`, which ensures that a byproduct `z` is always accepted in addition to `x` and `y`.
Callable wrapper for conditions `c`, which ensures that a byproduct `z` is always accepted in addition to `x` and `y`.

The type parameter `byproduct` is a boolean stating whether or not `c` natively accepts `z`.

# Fields

- `c::C`: Callable returning an array that must be equal to zero.
- `backend::B`: Autodiff backend compatible with AbstractDifferentiation.jl, which will be used to differentiate the conditions. It defaults to `nothing`, which means the conditions will use the same backend as the implicit function they belong to.
"""
struct Conditions{byproduct,C}
struct Conditions{byproduct,C,B<:Union{Nothing,<:AbstractBackend}}
c::C
function Conditions{byproduct}(c::C) where {byproduct,C}
return new{byproduct,C}(c)
backend::B
function Conditions{byproduct}(c::C, backend::B=nothing) where {byproduct,C,B}
return new{byproduct,C,B}(c, backend)
end
end

function Base.show(io::IO, conditions::Conditions{byproduct}) where {byproduct}
return print(io, "Conditions{$byproduct}($(conditions.c))")
return print(io, "Conditions{$byproduct}($(conditions.c), $(conditions.backend))")
end

(conditions::Conditions{true})(x, y, z; kwargs...) = conditions.c(x, y, z; kwargs...)
Expand Down
43 changes: 23 additions & 20 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ Differentiable wrapper for an implicit function defined by a forward mapping and
You can construct an `ImplicitFunction` from a forward mapping `f` and conditions `c`, both of which must be callables (function-like objects).
While `f` does not not need to be compatible with automatic differentiation, `c` has to be.

ImplicitFunction(f, c[, HandleByproduct()])
ImplicitFunction(f, c, linear_solver[, HandleByproduct()])
ImplicitFunction(
f, c[, HandleByproduct()];
linear_solver=IterativeLinearSolver(), conditions_backend=nothing
)

# Callable behavior

Expand All @@ -23,49 +25,50 @@ An `ImplicitFunction` object `implicit` behaves like a function, and every call
- If `HandleByproduct()` is passed as an argument to the constructor, we assume instead that the forward mapping is `x -> (y(x),z(x))` and the conditions are `c(x,y(x),z(x)) = 0`. In this case, `z(x)` can contain additional information generated by the forward mapping, but beware that we consider it constant for differentiation purposes.

Given `x ∈ ℝⁿ` and `y ∈ ℝᵈ`, we need as many conditions as output dimensions: `c(x,y,z) ∈ ℝᵈ`. We can then compute the Jacobian of `y(⋅)` using the implicit function theorem:
```
∂₂c(x,y(x),z(x)) * ∂y(x) = -∂₁c(x,y(x),z(x))
```

∂₂c(x,y(x),z(x)) * ∂y(x) = -∂₁c(x,y(x),z(x))

This requires solving a linear system `A * J = -B`, where `A ∈ ℝᵈˣᵈ`, `B ∈ ℝᵈˣⁿ` and `J ∈ ℝᵈˣⁿ`.

# Fields

- `forward::FF`: a wrapper of type [`Forward`](@ref) around the callable `f`
- `conditions::CC`: a wrapper of type [`Conditions`](@ref) around the callable `c`
- `conditions::CC`: a wrapper of type [`Conditions`](@ref) around the callable `c`, with tunable autodiff backend
- `linear_solver::LS`: an object subtyping [`AbstractLinearSolver`](@ref) (defaults to [`IterativeLinearSolver`](@ref)).
"""
struct ImplicitFunction{FF<:Forward,CC<:Conditions,LS<:AbstractLinearSolver}
forward::FF
conditions::CC
linear_solver::LS

function ImplicitFunction(f, c, linear_solver::AbstractLinearSolver)
function ImplicitFunction(
f,
c;
linear_solver::AbstractLinearSolver=IterativeLinearSolver(),
conditions_backend=nothing,
)
forward = Forward{false}(f)
conditions = Conditions{false}(c)
conditions = Conditions{false}(c, conditions_backend)
return new{typeof(forward),typeof(conditions),typeof(linear_solver)}(
forward, conditions, linear_solver
)
end

function ImplicitFunction(f, c, linear_solver::AbstractLinearSolver, ::HandleByproduct)
function ImplicitFunction(
f,
c,
::HandleByproduct;
linear_solver::AbstractLinearSolver=IterativeLinearSolver(),
conditions_backend=nothing,
)
forward = Forward{true}(f)
conditions = Conditions{true}(c)
conditions = Conditions{true}(c, conditions_backend)
return new{typeof(forward),typeof(conditions),typeof(linear_solver)}(
forward, conditions, linear_solver
)
end
end

function ImplicitFunction(f, c)
linear_solver = IterativeLinearSolver()
return ImplicitFunction(f, c, linear_solver)
end

function ImplicitFunction(f, c, ::HandleByproduct)
linear_solver = IterativeLinearSolver()
return ImplicitFunction(f, c, linear_solver, HandleByproduct())
end

function Base.show(io::IO, implicit::ImplicitFunction)
@unpack forward, conditions, linear_solver = implicit
return print(io, "ImplicitFunction($(forward.f), $(conditions.c), $linear_solver)")
Expand Down
19 changes: 15 additions & 4 deletions test/systematic.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using AbstractDifferentiation: ForwardDiffBackend, ZygoteBackend
using ChainRulesCore
using ChainRulesTestUtils
using ForwardDiff
Expand Down Expand Up @@ -47,17 +48,19 @@ function mysqrt_byproduct(x::AbstractArray)
return y, z
end

function make_implicit_sqrt(linear_solver)
function make_implicit_sqrt(linear_solver, conditions_backend=nothing)
forward(x) = mysqrt(x)
conditions(x, y) = y .^ 2 .- x
implicit = ImplicitFunction(forward, conditions, linear_solver)
implicit = ImplicitFunction(forward, conditions; linear_solver, conditions_backend)
return implicit
end

function make_implicit_sqrt_byproduct(linear_solver)
function make_implicit_sqrt_byproduct(linear_solver, conditions_backend=nothing)
forward(x) = mysqrt_byproduct(x)
conditions(x, y, z) = y .^ z .- x
implicit = ImplicitFunction(forward, conditions, linear_solver, HandleByproduct())
implicit = ImplicitFunction(
forward, conditions, HandleByproduct(); linear_solver, conditions_backend
)
return implicit
end

Expand Down Expand Up @@ -168,19 +171,27 @@ for linear_solver in linear_solver_candidates, x in x_candidates

testsetname = "$(typeof(x)) - $(typeof(linear_solver))"
implicit_sqrt = make_implicit_sqrt(linear_solver)
implicit_sqrt_forwarddiff = make_implicit_sqrt(linear_solver, ForwardDiffBackend())
implicit_sqrt_zygote = make_implicit_sqrt(linear_solver, ZygoteBackend())
implicit_sqrt_byproduct = make_implicit_sqrt_byproduct(linear_solver)

@testset verbose = true "$testsetname" begin
@testset "Call" begin
test_implicit_call(implicit_sqrt, x; y_true)
test_implicit_call(implicit_sqrt_forwarddiff, x; y_true)
test_implicit_call(implicit_sqrt_zygote, x; y_true)
test_implicit_call(implicit_sqrt_byproduct, x; y_true)
end
@testset "Forward" begin
test_implicit_forward(implicit_sqrt, x; y_true, J_true)
test_implicit_forward(implicit_sqrt_forwarddiff, x; y_true, J_true)
@test_skip test_implicit_forward(implicit_sqrt_zygote, x; y_true, J_true) # TODO: fix AD bug?
test_implicit_forward(implicit_sqrt_byproduct, x; y_true, J_true)
end
@testset "Reverse" begin
test_implicit_reverse(implicit_sqrt, x; y_true, J_true)
test_implicit_reverse(implicit_sqrt_forwarddiff, x; y_true, J_true)
@test_skip test_implicit_reverse(implicit_sqrt_zygote, x; y_true, J_true) # TODO: fix AD bug?
test_implicit_reverse(implicit_sqrt_byproduct, x; y_true, J_true)
end
end
Expand Down