Skip to content

Commit

Permalink
Customize backend for conditions (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Aug 5, 2023
1 parent 9d34d8b commit 450d8aa
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 25 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Expand All @@ -50,10 +51,11 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Documenter", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Pkg", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"]
test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Documenter", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Pkg", "Random", "ReverseDiff", "SparseArrays", "StaticArrays", "Test", "Zygote"]
13 changes: 12 additions & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.9.2"
manifest_format = "2.0"
project_hash = "01aa55c2eb0613f5724b3240c0f6da431aa9c124"
project_hash = "7637ae79192e640fcba617dc693c773a3da65a26"

[[deps.AMD]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse_jll"]
Expand Down Expand Up @@ -285,6 +285,11 @@ weakdeps = ["StaticArrays"]
[deps.ForwardDiff.extensions]
ForwardDiffStaticArraysExt = "StaticArrays"

[[deps.FunctionWrappers]]
git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e"
uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
version = "1.1.3"

[[deps.Functors]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "9a68d75d466ccc1218d0552a8e1631151c569545"
Expand Down Expand Up @@ -595,6 +600,12 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"

[[deps.ReverseDiff]]
deps = ["ChainRulesCore", "DiffResults", "DiffRules", "ForwardDiff", "FunctionWrappers", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "SpecialFunctions", "StaticArrays", "Statistics"]
git-tree-sha1 = "18ed404e60753972ffd459c0db9725ff34d6e60d"
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
version = "1.15.0"

[[deps.Richardson]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "e03ca566bec93f8a3aeb059c8ef102f268a38949"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
10 changes: 8 additions & 2 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,23 @@

## Supported autodiff backends

To differentiate an `ImplicitFunction`, the following backends are supported.

| Backend | Forward mode | Reverse mode |
| ---------------------------------------------------------------------- | ------------ | ------------ |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | yes | - |
| [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | yes | soon |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | someday | someday |

By default, the conditions are differentiated with the same backend as the `ImplicitFunction` that contains them.
However, this can be switched to any backend compatible with [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) (i.e. a subtype of `AD.AbstractBackend`).
You can specify it with the `conditions_backend` keyword argument when constructing an `ImplicitFunction`.

## Writing conditions

We recommend that the conditions themselves do not involve calls to autodiff, even when they describe a gradient.
Otherwise, you will need to make sure that nested autodiff works well in your case.
For instance, if you're differentiating your implicit function in reverse mode with Zygote.jl, you may want to use [`Zygote.forwarddiff`](https://fluxml.ai/Zygote.jl/stable/utils/#Zygote.forwarddiff) to wrap the conditions and differentiate them with ForwardDiff.jl instead.
For instance, if you're differentiating your implicit function (and your conditions) in reverse mode with Zygote.jl, you may want to use ForwardDiff.jl mode to compute gradients inside the conditions.

## Matrices and higher-order arrays

Expand Down Expand Up @@ -56,7 +62,7 @@ Keep in mind that derivatives of `z` will not be computed: the byproduct is cons

## Performance tips

If you work with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) if you seek increased performance.
If you work with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.

## Modeling tips

Expand Down
16 changes: 14 additions & 2 deletions ext/ImplicitDifferentiationChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module ImplicitDifferentiationChainRulesExt

using AbstractDifferentiation: ReverseRuleConfigBackend
using AbstractDifferentiation: AbstractBackend, ReverseRuleConfigBackend
using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, ZeroTangent, rrule, unthunk
using ImplicitDifferentiation: ImplicitFunction, reverse_operators, solve
using LinearAlgebra: lmul!, mul!
Expand All @@ -20,13 +20,25 @@ function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
) where {R}
y_or_yz = implicit(x; kwargs...)
backend = ReverseRuleConfigBackend(rc)
backend = reverse_conditions_backend(rc, implicit)
Aᵀ_op, Bᵀ_op = reverse_operators(backend, implicit, x, y_or_yz; kwargs)
byproduct = y_or_yz isa Tuple
implicit_pullback = ImplicitPullback{byproduct}(Aᵀ_op, Bᵀ_op, implicit.linear_solver, x)
return y_or_yz, implicit_pullback
end

function reverse_conditions_backend(
rc::RuleConfig, ::ImplicitFunction{F,C,L,Nothing}
) where {F,C,L}
return ReverseRuleConfigBackend(rc)
end

function reverse_conditions_backend(
::RuleConfig, implicit::ImplicitFunction{F,C,L,<:AbstractBackend}
) where {F,C,L}
return implicit.conditions_backend
end

struct ImplicitPullback{byproduct,A,B,L,X}
Aᵀ_op::A
Bᵀ_op::B
Expand Down
14 changes: 12 additions & 2 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ else
using ..ForwardDiff: Dual, Partials, jacobian, partials, value
end

using AbstractDifferentiation: AbstractBackend, ForwardDiffBackend, pushforward_function
using AbstractDifferentiation: AbstractBackend, ForwardDiffBackend
using ImplicitDifferentiation: ImplicitFunction, DirectLinearSolver, IterativeLinearSolver
using ImplicitDifferentiation: forward_operators, solve, identity_break_autodiff
using LinearAlgebra: lmul!, mul!
Expand All @@ -29,7 +29,7 @@ function (implicit::ImplicitFunction)(
y_or_yz = implicit(x; kwargs...)
y = _output(y_or_yz)

backend = ForwardDiffBackend()
backend = forward_conditions_backend(implicit)
A_op, B_op = forward_operators(backend, implicit, x, y_or_yz; kwargs)

dy = ntuple(Val(N)) do k
Expand All @@ -55,6 +55,16 @@ function (implicit::ImplicitFunction)(
end
end

function forward_conditions_backend(::ImplicitFunction{F,C,L,Nothing}) where {F,C,L}
return ForwardDiffBackend()
end

function forward_conditions_backend(
implicit::ImplicitFunction{F,C,L,<:AbstractBackend}
) where {F,C,L}
return implicit.conditions_backend
end

_output(y::AbstractArray) = y
_output(yz::Tuple) = yz[1]
_byproduct(yz::Tuple) = yz[2]
Expand Down
42 changes: 28 additions & 14 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
ImplicitFunction{F,C,L<:AbstractLinearSolver}
ImplicitFunction{F,C,L,B}
Differentiable wrapper for an implicit function defined by a forward mapping `y` and a set of conditions `c`.
Expand All @@ -12,15 +12,19 @@ This requires solving a linear system `A * J = -B`.
# Constructors
You can construct an `ImplicitFunction` from two callables (function-like objects) `forward` and `conditions`.
While `forward` does not not need to be compatible with automatic differentiation, `conditions` has to be.
ImplicitFunction(forward, conditions; linear_solver=IterativeLinearSolver())
There are two possible signatures for `forward` and `conditions`, which must be consistent with one another:
ImplicitFunction(
forward, conditions;
linear_solver=IterativeLinearSolver(), conditions_backend=nothing,
)
1. Standard: `forward(x; kwargs...) = y` and `conditions(x, y; kwargs...) = c`
2. Byproduct: `forward(x; kwargs...) = (y, z)` and `conditions(x, y, z; kwargs...) = c`.
While `forward` does not not need to be compatible with automatic differentiation, `conditions` has to be (with the provided `conditions_backend` if there is one).
There are two possible signatures for `forward` and `conditions`, which must be consistent with one another:
1. Standard: `forward(x; kwargs...) = y` and `conditions(x, y; kwargs...) = c`
2. Byproduct: `forward(x; kwargs...) = (y, z)` and `conditions(x, y, z; kwargs...) = c`.
In both cases, `x`, `y` and `c` must be arrays with `size(y) = size(c)`.
In the second case, the byproduct `z` can be an arbitrary object generated by `forward`, but beware that we consider it constant for differentiation purposes.
Expand All @@ -36,23 +40,33 @@ This returns exactly `implicit.forward(x; kwargs...)`, which as we mentioned can
- `forward::F`
- `conditions::C`
- `linear_solver::L`
- `linear_solver::L<:AbstractLinearSolver`
- `conditions_backend::B<:Union{Nothing,AbstractBackend}`
!!! warning "Warning"
At the moment, `conditions_backend` can only be `nothing` or `AD.ForwardDiffBackend()`. We are investigating why the other backends fail.
"""
struct ImplicitFunction{F,C,L<:AbstractLinearSolver}
struct ImplicitFunction{F,C,L<:AbstractLinearSolver,B<:Union{Nothing,AbstractBackend}}
forward::F
conditions::C
linear_solver::L
conditions_backend::B

function ImplicitFunction(
forward::F, conditions::C; linear_solver::L=IterativeLinearSolver()
) where {F,C,L}
return new{F,C,L}(forward, conditions, linear_solver)
forward::F,
conditions::C;
linear_solver::L=IterativeLinearSolver(),
conditions_backend::B=nothing,
) where {F,C,L,B}
return new{F,C,L,B}(forward, conditions, linear_solver, conditions_backend)
end
end

function Base.show(io::IO, implicit::ImplicitFunction)
@unpack forward, conditions, linear_solver = implicit
return print(io, "ImplicitFunction($forward, $conditions, $linear_solver)")
@unpack forward, conditions, linear_solver, conditions_backend = implicit
return print(
io, "ImplicitFunction($forward, $conditions, $linear_solver, $conditions_backend)"
)
end

function (implicit::ImplicitFunction)(x::AbstractArray; kwargs...)
Expand Down
15 changes: 12 additions & 3 deletions test/systematic.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import AbstractDifferentiation as AD
using ChainRulesCore
using ChainRulesTestUtils
using FiniteDifferences: FiniteDifferences
using ForwardDiff: ForwardDiff
import ImplicitDifferentiation as ID
using ImplicitDifferentiation: ImplicitFunction, identity_break_autodiff
using ImplicitDifferentiation: DirectLinearSolver, IterativeLinearSolver
using JET
using LinearAlgebra
using Random
using ReverseDiff: ReverseDiff
using StaticArrays
using Test
using Zygote: Zygote, ZygoteRuleConfig
Expand Down Expand Up @@ -237,12 +239,19 @@ x_candidates = (
);

linear_solver_candidates = (IterativeLinearSolver(), DirectLinearSolver())
conditions_backend_candidates = (nothing, AD.ForwardDiffBackend());
# conditions_backend_failing_candidates = (
# AD.ZygoteBackend(), AD.FiniteDifferencesBackend, AD.ReverseDiffBackend()()
# ) # TODO: understand why

for linear_solver in linear_solver_candidates,
conditions_backend in conditions_backend_candidates,
x in x_candidates

for linear_solver in linear_solver_candidates, x in x_candidates
x isa StaticArray && linear_solver isa IterativeLinearSolver && continue
testsetname = "$(typeof(linear_solver)) - $(typeof(x))"
testsetname = "$(typeof(linear_solver)) - $(typeof(conditions_backend)) - $(typeof(x))"
@info "$testsetname"
@testset "$testsetname" begin
test_implicit(x; linear_solver)
test_implicit(x; linear_solver, conditions_backend)
end
end

0 comments on commit 450d8aa

Please sign in to comment.