Skip to content

Commit

Permalink
Fix docstring for value_and_pullback_function (#125)
Browse files Browse the repository at this point in the history
* Fix docstring for value_and_pullback_function

* Remove comma

* Change names to pff and pbf
  • Loading branch information
gdalle authored Jan 11, 2024
1 parent 19e7d88 commit afec712
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <[email protected]> and contributors"]
version = "0.6.0"
version = "0.6.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
69 changes: 38 additions & 31 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,50 +188,54 @@ end
"""
AD.pushforward_function(ab::AD.AbstractBackend, f, xs...)
Return the pushforward function `pf` of the function `f` at the inputs `xs` using backend `ab`.
Return the pushforward function `pff` of the function `f` at the inputs `xs` using backend `ab`.
The pushfoward function `pf` accepts as input a `Tuple` of tangents, one for each element in `xs`.
If `xs` consists of a single element, `pf` can also accept a single tangent instead of a 1-tuple.
The pushfoward function `pff` accepts as input a `Tuple` of tangents, one for each element in `xs`.
If `xs` consists of a single element, `pff` can also accept a single tangent instead of a 1-tuple.
"""
function pushforward_function(ab::AbstractBackend, f, xs...)
return (ds) -> begin
return jacobian(
lowest(ab),
(xds...,) -> begin
if ds isa Tuple
@assert length(xs) == length(ds)
newxs = xs .+ ds .* xds
return f(newxs...)
else
newx = only(xs) + ds * only(xds)
return f(newx)
end
end,
_zero.(xs, ds)...,
)
function pff(ds)
function pff_aux(xds...)
if ds isa Tuple
@assert length(xs) == length(ds)
newxs = xs .+ ds .* xds
return f(newxs...)
else
newx = only(xs) + ds * only(xds)
return f(newx)
end
end
return jacobian(lowest(ab), pff_aux, _zero.(xs, ds)...)
end
return pff
end

"""
AD.value_and_pushforward_function(ab::AD.AbstractBackend, f, xs...)
Return a function that, given tangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pushforward function `AD.pushforward_function(ab, f, xs...)` applied to `ts`.
Return a single function `vpff` which, given tangents `ts`, computes the tuple `(v, p) = vpff(ts)` composed of
- the function value `v = f(xs...)`
- the pushforward value `p = pff(ts)` given by the pushforward function `pff = AD.pushforward_function(ab, f, xs...)` applied to `ts`.
See also [`AbstractDifferentiation.pushforward_function`](@ref).
!!! warning
This name should be understood as "(value and pushforward) function", and thus is not aligned with the reverse mode counterpart [`AbstractDifferentiation.value_and_pullback_function`](@ref).
"""
function value_and_pushforward_function(ab::AbstractBackend, f, xs...)
n = length(xs)
value = f(xs...)
pf_function = pushforward_function(lowest(ab), f, xs...)
pff = pushforward_function(lowest(ab), f, xs...)

return ds -> begin
function vpff(ds)
if !(ds isa Tuple)
ds = (ds,)
end
@assert length(ds) == n
pf = pf_function(ds)
return value, pf
return value, pff(ds)
end
return vpff
end

_zero(::Number, d::Number) = zero(d)
Expand All @@ -253,10 +257,10 @@ end
"""
AD.pullback_function(ab::AD.AbstractBackend, f, xs...)
Return the pullback function `pb` of the function `f` at the inputs `xs` using backend `ab`.
Return the pullback function `pbf` of the function `f` at the inputs `xs` using backend `ab`.
The pullback function `pb` accepts as input a `Tuple` of cotangents, one for each output of `f`.
If `f` has a single output, `pb` can also accept a single input instead of a 1-tuple.
The pullback function `pbf` accepts as input a `Tuple` of cotangents, one for each output of `f`.
If `f` has a single output, `pbf` can also accept a single input instead of a 1-tuple.
"""
function pullback_function(ab::AbstractBackend, f, xs...)
_, pbf = value_and_pullback_function(ab, f, xs...)
Expand All @@ -266,14 +270,17 @@ end
"""
AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)
Return a function that, given cotangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pullback function `AD.pullback_function(ab, f, xs...)` applied to `ts`.
Return a tuple `(v, pbf)` of the function value `v = f(xs...)` and the pullback function `pbf = AD.pullback_function(ab, f, xs...)`.
See also [`AbstractDifferentiation.pullback_function`](@ref).
!!! warning
This name should be understood as "value and (pullback function)", and thus is not aligned with the forward mode counterpart [`AbstractDifferentiation.value_and_pushforward_function`](@ref).
"""
function value_and_pullback_function(ab::AbstractBackend, f, xs...)
value = f(xs...)
function pullback_function(ws)
function pullback_gradient_function(_xs...)
function pbf(ws)
function pbf_aux(_xs...)
vs = f(_xs...)
if ws isa Tuple
@assert length(vs) == length(ws)
Expand All @@ -282,9 +289,9 @@ function value_and_pullback_function(ab::AbstractBackend, f, xs...)
return _dot(vs, ws)
end
end
return gradient(lowest(ab), pullback_gradient_function, xs...)
return gradient(lowest(ab), pbf_aux, xs...)
end
return value, pullback_function
return value, pbf
end

struct LazyDerivative{B,F,X}
Expand Down

0 comments on commit afec712

Please sign in to comment.