Close issue #349 (#352)
* WIP

* WIP

* Add extension for ForwardDiff

* Update ext/RootsForwardDiffExt.jl

Co-authored-by: David Widmann <[email protected]>

* Update ext/RootsForwardDiffExt.jl

Co-authored-by: David Widmann <[email protected]>

* Update ext/RootsForwardDiffExt.jl

Co-authored-by: David Widmann <[email protected]>

* Update ext/RootsForwardDiffExt.jl

Co-authored-by: David Widmann <[email protected]>

* cleanup

* add keyword parameter test

* only weakdep; thx!

* cleanup

---------

Co-authored-by: David Widmann <[email protected]>
jverzani and devmotion authored Mar 7, 2023
1 parent c0cb2ba commit 57f0c33
Showing 5 changed files with 68 additions and 10 deletions.
7 changes: 7 additions & 0 deletions Project.toml
@@ -8,9 +8,16 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
RootsForwardDiffExt = "ForwardDiff"

[compat]
ChainRulesCore = "1"
CommonSolve = "0.1, 0.2"
ForwardDiff = "0.10"
Setfield = "0.7, 0.8, 1"
julia = "1.0"

21 changes: 13 additions & 8 deletions docs/src/roots.md
@@ -506,9 +506,12 @@ savefig("flight.svg"); nothing #hide
To maximize the range we solve for the lone critical point of `howfar`
within reasonable starting points.

- The automatic differentiation provided by `ForwardDiff` will
- work through a call to `find_zero` **if** the initial point has the proper type (depending on an expression of `theta` in this case).
- As we use `200*cosd(theta)-5` for a starting point, this is satisfied.
+ As of version `v"1.9"` of `Julia`, the automatic differentiation provided by
+ `ForwardDiff` will bypass working through a call to `find_zero`. Prior
+ to this version, automatic differentiation will work *if* the
+ initial point has the proper type (depending on an expression of
+ `theta` in this case). As we use `200*cosd(theta)-5` for a starting
+ point, this is satisfied.

```jldoctest roots
julia> (tstar = find_zero(D(howfar), 45)) ≈ 26.2623089
@@ -532,7 +535,7 @@ In the last example, the question of how the distance varies with the angle is c

In general, for functions with parameters, ``f(x,p)``, derivatives with respect to the ``p`` variable(s) may be of interest.

- A first attempt, may be to try and auto-differentiate the output of `find_zero`. For example:
+ A first attempt, as shown above, may be to try to auto-differentiate the output of `find_zero`. For example:

```@example roots
f(x, p) = x^2 - p # p a scalar
Expand All @@ -544,17 +547,19 @@ F(p) = find_zero(f, one(p), Order1(), p)
ForwardDiff.derivative(F, p)
```

- There are issues with this approach, though here it finds the correct answer, as will be seen: a) it is not as performant as what we will discuss next, b) the subtle use of `one(p)` for the starting point is needed to ensure the type for the ``x`` values is correct, and c) not all algorithms will work, in particular `Bisection` is not amenable to this approach:
+ Prior to version `v"1.9"` of `Julia`, there were issues with this approach, though in this case it finds the correct answer: a) it is not as performant as the method discussed next; b) the subtle use of `one(p)` for the starting point is needed to ensure the correct type for the ``x`` values; and c) not all algorithms will work; in particular, `Bisection` is not amenable to this approach.

```@example roots
F(p) = find_zero(f, (zero(p), one(p)), Roots.Bisection(), p)
ForwardDiff.derivative(F, 1/2)
```

- The `0.0` is the wrong answer, as the duals used by `ForwardDiff.derivative` do not flow through the `Bisection` algorithm.
+ This will be `0.0` if the differentiation is propagated through the algorithm, as the dual numbers used by `ForwardDiff.derivative` do not flow through `Bisection`.
+ With `v"1.9"` of `Julia` or later, the derivative is calculated correctly through the method described below.


- Using the implicit function theorem and following these [notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf) or this [paper](https://arxiv.org/pdf/2105.15183.pdf) on the adjoint method, we can auto-differentiate without pushing that machinery through `find_zero`.
+ Using the implicit function theorem and following these [notes](https://math.mit.edu/~stevenj/18.336/adjoint.pdf), this [paper](https://arxiv.org/pdf/2105.15183.pdf) on the adjoint method, or the methods applied more generally in the [ImplicitDifferentiation](https://github.com/gdalle/ImplicitDifferentiation.jl) package, we can auto-differentiate without pushing that machinery through `find_zero`.
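In outline, one line of calculus carries the whole approach: differentiating the identity that defines the zero with respect to ``p`` gives

```math
f(x^{*}(p), p) = 0
\quad\Longrightarrow\quad
\frac{\partial f}{\partial x}\,\frac{dx^{*}}{dp} + \frac{\partial f}{\partial p} = 0
\quad\Longrightarrow\quad
\frac{dx^{*}}{dp} = -\frac{f_p}{f_x}.
```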

The solution, ``x^*(p)``, provided by `find_zero` depends on the parameter(s), ``p``. Notationally,

@@ -616,7 +621,7 @@ fₚ = ForwardDiff.gradient(p -> f(xᵅ, p), p)
- fₚ / fₓ
```

- The package provides a `ChainRulesCore.rrule` and `ChainRulesCore.frule` implementation that should allow automatic differentiation packages relying on `ChainRulesCore` (e.g., `Zygote`) to differentiate in `p` using the above approach. (Thanks to `@devmotion` for help here.)
+ With version `v"1.9"` or later of `Julia`, the package provides an extension allowing `ForwardDiff` to be used directly to find derivatives or gradients, as above.
+ It also provides `ChainRulesCore.rrule` and `ChainRulesCore.frule` implementations that should allow automatic differentiation packages relying on `ChainRulesCore` (e.g., `Zygote`) to differentiate in `p` using the same approach. (Thanks to `@devmotion` for much help here.)



26 changes: 26 additions & 0 deletions ext/RootsForwardDiffExt.jl
@@ -0,0 +1,26 @@
module RootsForwardDiffExt

using Roots
using ForwardDiff
import ForwardDiff: Dual, value, partials

# For ForwardDiff we add a `solve` method for Dual types
function Roots.solve(ZP::ZeroProblem,
                     M::Roots.AbstractUnivariateZeroMethod,
                     𝐩::Union{Dual{T},
                              AbstractArray{<:Dual{T,<:Real}}};
                     kwargs...) where {T}
    f = ZP.F
    pᵥ = value.(𝐩)

    xᵅ = solve(ZP, M, pᵥ; kwargs...)
    𝐱ᵅ = Dual{T}(xᵅ, one(xᵅ))

    fₓ = partials(f(𝐱ᵅ, pᵥ), 1)
    fₚ = partials(f(xᵅ, 𝐩))

    Dual{T}(xᵅ, -fₚ / fₓ)
end

end
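The recipe in this extension is language-agnostic: solve at the primal value of the parameter, then attach the sensitivity `-fₚ/fₓ` given by the implicit function theorem. A minimal sketch of the same idea, with hypothetical names, where a hand-rolled Newton iteration stands in for `solve` and central finite differences stand in for `ForwardDiff`'s partials:

```python
import math

def solve_with_sensitivity(f, solve, p, h=1e-6):
    """Root-find at the value of p, then attach dx*/dp = -f_p / f_x
    (implicit function theorem), mirroring the Dual-number method above."""
    xstar = solve(p)  # primal solve: f(xstar, p) == 0
    fx = (f(xstar + h, p) - f(xstar - h, p)) / (2 * h)  # partial in x
    fp = (f(xstar, p + h) - f(xstar, p - h)) / (2 * h)  # partial in p
    return xstar, -fp / fx

# example: f(x, p) = x^2 - p, whose positive zero is sqrt(p)
f = lambda x, p: x**2 - p

def newton_solve(p, x=1.0):
    # stand-in for the root finder: Newton steps for x^2 - p
    for _ in range(60):
        x -= f(x, p) / (2 * x)
    return x

xstar, dxdp = solve_with_sensitivity(f, newton_solve, 2.0)
```

Here the zero is `sqrt(p)`, so the returned sensitivity should be close to `1/(2*sqrt(p))`.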
5 changes: 3 additions & 2 deletions src/find_zero.jl
@@ -217,13 +217,14 @@ function find_zero(
    tracks::AbstractTracks=NullTracks(),
    kwargs...,
)

    xstar = solve(
        ZeroProblem(f, x0),
        M,
        p′ === nothing ? p : p′;
        verbose=verbose,
        tracks=tracks,
-       kwargs...,
+       kwargs...
    )

    isnan(xstar) && throw(ConvergenceFailed("Algorithm failed to converge"))
@@ -305,7 +306,7 @@ end

function init(𝑭𝑿::ZeroProblem, p′=nothing; kwargs...)
    M = length(𝑭𝑿.x₀) == 1 ? Order0() : AlefeldPotraShi()
-   init(𝑭𝑿, M; p=p′, kwargs...)
+   init(𝑭𝑿, M, p′; kwargs...)
end

function init(
19 changes: 19 additions & 0 deletions test/test_find_zero.jl
@@ -555,3 +555,22 @@ end
@test_throws ArgumentError Roots._extrema((π, π))
@test_throws ArgumentError Roots._extrema([π, π])
end

@testset "sensitivity" begin
    # Issue #349
    if VERSION >= v"1.9.0-"
        f(x, p) = cos(x) - first(p)*x
        x₀ = (0, pi/2)
        F(p) = solve(ZeroProblem(f, x₀), Bisection(), p)
        G(p) = find_zero(f, x₀, Bisection(), p)
        H(p) = find_zero(f, x₀, Bisection(); p = p)

        ∂ = -0.4416107917053284
        @test ForwardDiff.derivative(F, 1.0) ≈ ∂
        @test ForwardDiff.gradient(F, [1.0, 2])[1] ≈ ∂
        @test ForwardDiff.derivative(G, 1.0) ≈ ∂
        @test ForwardDiff.gradient(G, [1.0, 2])[1] ≈ ∂
        @test ForwardDiff.derivative(H, 1.0) ≈ ∂
        @test ForwardDiff.gradient(H, [1.0, 2])[1] ≈ ∂
    end
end
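The hard-coded derivative in these tests can be sanity-checked independently: for `f(x, p) = cos(x) - p*x` the implicit function theorem gives `dx*/dp = -fₚ/fₓ = x*/(-sin(x*) - p)`. A quick numerical check, sketched here in Python with a textbook bisection standing in for `Roots.Bisection`:

```python
import math

def bisection(f, a, b, tol=1e-15):
    # textbook bisection; assumes f(a) and f(b) bracket a zero
    fa = f(a)
    while b - a > tol:
        m = 0.5 * (a + b)
        if fa * f(m) <= 0:
            b = m
        else:
            a, fa = m, f(m)
    return 0.5 * (a + b)

p = 1.0
f = lambda x: math.cos(x) - p * x
xstar = bisection(f, 0.0, math.pi / 2)  # root of cos(x) = p*x
fx = -math.sin(xstar) - p               # partial of f in x at (x*, p)
fp = -xstar                             # partial of f in p at (x*, p)
dxdp = -fp / fx                         # implicit function theorem
```

The computed `dxdp` should agree with the tested value `-0.4416107917053284` to near machine precision.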
