Replace AbstractDifferentiation with DifferentiationInterface
gdalle committed Jul 30, 2024
1 parent 67dac79 commit dbed1ad
Showing 12 changed files with 85 additions and 79 deletions.
6 changes: 3 additions & 3 deletions Project.toml
@@ -1,13 +1,14 @@
name = "MuseInference"
uuid = "43b88160-90c7-4f71-933b-9d65205cd921"
authors = ["Marius Millea <[email protected]>"]
version = "0.2.4"
version = "0.3.0"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
@@ -32,7 +33,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
AbstractDifferentiation = "0.5"
ComponentArrays = "0.12.3, 0.13, 0.14, 0.15"
CovarianceEstimation = "0.2.7"
Distributions = "0.25.36"
4 changes: 2 additions & 2 deletions docs/src/index.md
@@ -33,7 +33,7 @@ First, load up the packages we'll need:

```@example 1
using MuseInference, Turing
using AbstractDifferentiation, Dates, LinearAlgebra, Printf, Plots, Random, Zygote
using ADTypes, Dates, LinearAlgebra, Printf, Plots, Random, Zygote
Turing.setadbackend(:zygote)
using Logging # hide
Logging.disable_logging(Logging.Info) # hide
@@ -164,7 +164,7 @@ prob = SimpleMuseProblem(
function logPrior(θ)
-θ^2/(2*3^2)
end;
autodiff = AbstractDifferentiation.ZygoteBackend()
autodiff = ADTypes.AutoZygote()
)
nothing # hide
```
5 changes: 3 additions & 2 deletions src/MuseInference.jl
@@ -1,6 +1,7 @@
module MuseInference

import AbstractDifferentiation as AD
using ADTypes: ADTypes
import DifferentiationInterface as DI
using Base.Iterators: repeated
using ComponentArrays
using CovarianceEstimation
@@ -56,4 +57,4 @@ end
end
end

end
end
12 changes: 3 additions & 9 deletions src/ad.jl
@@ -1,18 +1,12 @@


# some convenient type-piracy for scalars
AD.gradient(ad::AD.AbstractBackend, f, x::Real) = AD.derivative(ad, f, x)
AD.hessian(ad::AD.AbstractBackend, f, x::Real) = first.(AD.hessian(ad, f∘first, [x]))

function optim_only_fg!(func, autodiff)
function optim_only_fg!(func, backend::ADTypes.AbstractADType)
Optim.only_fg!() do F, G, z
if G != nothing
f, g = AD.value_and_gradient(autodiff, func, z)
f, g = DI.value_and_gradient(func, backend, z)
G .= first(g)
return f
end
if F != nothing
return func(z)
end
end
end
end
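For reference, the key calling-convention change behind the edits above: AbstractDifferentiation took the backend as the first argument and returned gradients wrapped in a tuple (hence the `first(...)` calls being deleted throughout this commit), whereas DifferentiationInterface takes the function first and returns the gradient directly. A minimal sketch of the two styles, assuming Zygote is installed (`f` here is a toy function, not part of the package):

```julia
using ADTypes, Zygote
import DifferentiationInterface as DI

f(x) = sum(abs2, x)   # toy scalar-valued function
x = randn(3)

# old style: y, (g,) = AD.value_and_gradient(AD.ZygoteBackend(), f, x)
y, g = DI.value_and_gradient(f, ADTypes.AutoZygote(), x)
y ≈ f(x)    # true
g ≈ 2 .* x  # true: the gradient comes back directly, no tuple unwrapping
```

The do-block form used elsewhere in the commit, e.g. `DI.gradient(backend, x) do x ... end`, is the same call, since `do` passes the closure as the first argument.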
1 change: 0 additions & 1 deletion src/interface.jl
@@ -228,4 +228,3 @@ function check_self_consistency(
@test ∇θ_logLike(prob, x, z, θ, UnTransformedθ()) ≈ J(θ)' * ∇θ_logLike(prob, x, z, transform_θ(prob, θ), Transformedθ()) .+ ∇θ_V(θ) atol=atol
end
end

55 changes: 33 additions & 22 deletions src/muse.jl
Expand Up @@ -181,7 +181,9 @@ function muse!(
ẑs = getindex.(gẑs, :ẑ)

g_like′ = g_like_dat′ .- mean(g_like_sims′)
g_prior′ = AD.gradient(AD.ForwardDiffBackend(), θ′ -> logPriorθ(prob, θ′, Transformedθ()), θ′)[1]
g_prior′ = DI.gradient(ADTypes.AutoForwardDiff(), θ′) do θ′
logPriorθ(prob, θ′, Transformedθ())
end
g_post′ = g_like′ .+ g_prior′

# Jacobian
@@ -204,7 +206,9 @@ end
end
end

H_prior′ = AD.hessian(AD.ForwardDiffBackend(), θ′ -> logPriorθ(prob, θ′, Transformedθ()), θ′)[1]
H_prior′ = DI.hessian(ADTypes.AutoForwardDiff(), θ′) do θ′
logPriorθ(prob, θ′, Transformedθ())
end
H⁻¹_post′ = inv(inv(H⁻¹_like′) + H_prior′)

t = now() - t₀
@@ -288,8 +292,8 @@ Keyword arguments:
differentiation, rather than finite differences. Will require 2nd
order AD through your `logLike` so pay close attention to your
`prob.autodiff`. Either
`AD.HigherOrderBackend((AD.ForwardDiffBackend(),
AD.ZygoteBackend()))` or `AD.ForwardDiffBackend()` are recommended
`DifferentiationInterface.SecondOrder(ADTypes.AutoForwardDiff(),
ADTypes.AutoZygote())` or `ADTypes.AutoForwardDiff()` are recommended
(default: `false`)
"""
@@ -347,35 +351,40 @@ function get_H!(
end
T = eltype(z_start)

ad_fwd, ad_rev = AD.second_lowest(prob.autodiff), AD.lowest(prob.autodiff)
ad_fwd, ad_rev = if prob.autodiff isa DI.SecondOrder
# assume forward-over-reverse is provided
DI.outer(prob.autodiff), DI.inner(prob.autodiff)
else
prob.autodiff, prob.autodiff
end

## non-implicit-diff term
H1 = implicit_diff_H1_is_zero ? 𝟘 : copyto!(similar(𝟘), first(AD.jacobian(θ₀, backend=ad_fwd) do θ
H1 = implicit_diff_H1_is_zero ? 𝟘 : copyto!(similar(𝟘), DI.jacobian(ad_fwd, θ₀) do θ
local x, = sample_x_z(prob, copy(rng), θ)
first(AD.gradient(θ₀, backend=ad_rev) do θ′
DI.gradient(ad_rev, θ₀) do θ′
logLike(prob, x, ẑ, θ′, UnTransformedθ())
end)
end))
end
end)

## term involving dzMAP/dθ via implicit-diff (w/ conjugate-gradient linear solve)
dFdθ = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
first(AD.gradient(ẑ, backend=ad_rev) do z
dFdθ = DI.jacobian(ad_fwd, θ₀) do θ
DI.gradient(ad_rev, ẑ) do z
logLike(prob, x, z, θ, UnTransformedθ())
end)
end)
dFdθ1 = first(AD.jacobian(θ₀, backend=ad_fwd) do θ
end
end
dFdθ1 = DI.jacobian(ad_fwd, θ₀) do θ
local x, = sample_x_z(prob, copy(rng), θ)
first(AD.gradient(ẑ, backend=ad_rev) do z
DI.gradient(ad_rev, ẑ) do z
logLike(prob, x, z, θ₀, UnTransformedθ())
end)
end)
end
end
# A is the operation of the Hessian of logLike w.r.t. z
A = LinearMap{T}(length(z_start), isposdef=true, issymmetric=true, ishermitian=true) do w
first(AD.jacobian(0, backend=ad_fwd) do α
first(AD.gradient(ẑ + α * w, backend=ad_rev) do z
DI.jacobian(ad_fwd, 0) do α
DI.gradient(ad_rev, ẑ + α * w) do z
logLike(prob, x, z, θ₀, UnTransformedθ())
end)
end)
end
end
end
A⁻¹_dFdθ1 = pmap(pool_jac, eachcol(dFdθ1)) do w
A⁻¹_w = cg(A, w; implicit_diff_cg_kwargs..., log=true)
@@ -536,7 +545,9 @@ function finalize_result!(result::MuseResult, prob::AbstractMuseProblem)
@unpack H, J, θ = result
if H != nothing && J != nothing && θ != nothing
𝟘 = zero(J) # if θ::ComponentArray, helps keep component labels
H_prior = -AD.hessian(AD.ForwardDiffBackend(), θ -> logPriorθ(prob, θ, UnTransformedθ()), result.θ)[1]
H_prior = -DI.hessian(ADTypes.AutoForwardDiff(), result.θ) do θ
logPriorθ(prob, θ, UnTransformedθ())
end
result.Σ⁻¹ = H' * inv(J) * H + H_prior + 𝟘
result.Σ = inv(result.Σ⁻¹) + 𝟘
if length(result.θ) == 1
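Since the `implicit_diff` docstring above now recommends `DifferentiationInterface.SecondOrder`, here is a minimal sketch of that forward-over-reverse setup and of the `DI.outer`/`DI.inner` accessors that `get_H!` uses to split it (assumes ForwardDiff and Zygote are installed):

```julia
using ADTypes, ForwardDiff, Zygote
import DifferentiationInterface as DI

# the outer backend differentiates the inner backend's gradients,
# i.e. ForwardDiff jacobians of Zygote gradients (forward-over-reverse)
autodiff = DI.SecondOrder(ADTypes.AutoForwardDiff(), ADTypes.AutoZygote())

ad_fwd = DI.outer(autodiff)  # AutoForwardDiff(); what get_H! binds to `ad_fwd`
ad_rev = DI.inner(autodiff)  # AutoZygote(); what get_H! binds to `ad_rev`

DI.hessian(x -> sum(abs2, x), autodiff, randn(3))  # ≈ 2I, computed forward-over-reverse
```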
14 changes: 7 additions & 7 deletions src/simple.jl
@@ -12,7 +12,7 @@ struct SimpleMuseProblem{X,S,L,Gθ,Pθ,GZ,A} <: AbstractMuseProblem
end

@doc doc"""
SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); ad=AD.ForwardDiffBackend())
SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); ad=ADTypes.AutoForwardDiff())
Specify a MUSE problem by providing the simulation and posterior
evaluation code by-hand. The argument `x` should be the observed data.
@@ -44,8 +44,8 @@ end
and should return the prior $\log\mathcal{P}(\theta)$ for your
problem. The `autodiff` parameter should be either
`MuseInference.ForwardDiffBackend()` or
`MuseInference.ZygoteBackend()`, specifying which library to use for
`ADTypes.AutoForwardDiff()` or
`ADTypes.AutoZygote()`, specifying which library to use for
automatic differentiation through `logLike`.
@@ -69,20 +69,20 @@ prob = SimpleMuseProblem(
function logPrior(θ)
-θ^2/(2*3^2)
end;
autodiff = MuseInference.ZygoteBackend()
autodiff = ADTypes.AutoZygote()
)
# get solution
muse(prob, (θ=1,))
```
"""
function SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); autodiff::AD.AbstractBackend=AD.ForwardDiffBackend())
function SimpleMuseProblem(x, sample_x_z, logLike, logPriorθ=(θ->0); autodiff::ADTypes.AbstractADType=ADTypes.AutoForwardDiff())
SimpleMuseProblem(
x,
sample_x_z,
logLike,
(x,z,θ) -> first(AD.gradient(autodiff, θ -> logLike(x,z,θ), θ)),
(x,z,θ) -> first.(AD.value_and_gradient(autodiff, z -> logLike(x,z,θ), z)),
(x,z,θ) -> DI.gradient(θ -> logLike(x,z,θ), autodiff, θ),
(x,z,θ) -> DI.value_and_gradient(z -> logLike(x,z,θ), autodiff, z),
logPriorθ,
autodiff
)
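The constructor above wires the θ- and z-gradient callbacks as `DI` closures over `logLike`. A self-contained sketch of that pattern, with a made-up `logLike` (not the package's) and the default ForwardDiff backend:

```julia
using ADTypes, ForwardDiff
import DifferentiationInterface as DI

logLike(x, z, θ) = -sum(abs2, x .- θ .* z) / 2   # hypothetical log-likelihood
autodiff = ADTypes.AutoForwardDiff()

∇θ_logLike = (x, z, θ) -> DI.gradient(θ′ -> logLike(x, z, θ′), autodiff, θ)
logLike_and_∇z_logLike = (x, z, θ) -> DI.value_and_gradient(z′ -> logLike(x, z′, θ), autodiff, z)

x, z, θ = randn(4), randn(4), ones(4)
∇θ_logLike(x, z, θ)              # gradient w.r.t. θ, here (x .- θ .* z) .* z
logLike_and_∇z_logLike(x, z, θ)  # (scalar value, gradient w.r.t. z)
```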
22 changes: 13 additions & 9 deletions src/soss.jl
@@ -5,7 +5,7 @@ import .Soss.SimpleGraphs as SG

export SossMuseProblem

struct SossMuseProblem{A<:AD.AbstractBackend, M<:Soss.AbstractModel, MP<:Soss.AbstractModel} <: AbstractMuseProblem
struct SossMuseProblem{A<:ADTypes.AbstractADType, M<:Soss.AbstractModel, MP<:Soss.AbstractModel} <: AbstractMuseProblem
autodiff :: A
model :: M
model_for_prior :: MP
@@ -18,7 +18,7 @@ struct SossMuseProblem{A<:AD.AbstractBackend, M<:Soss.AbstractModel, MP<:Soss.Ab
end

@doc doc"""
SossMuseProblem(model; params, autodiff = ForwardDiffBackend())
SossMuseProblem(model; params, autodiff = AutoForwardDiff())
Specify a MUSE problem with a
[Soss](https://github.com/cscherrer/Soss.jl) model.
@@ -31,8 +31,8 @@ as a list of symbols. All other non-conditioned and non-`params`
variables will be considered the latent space.
The `autodiff` parameter should be either
`MuseInference.ForwardDiffBackend()` or
`MuseInference.ZygoteBackend()`, specifying which library to use for
`ADTypes.AutoForwardDiff()` or
`ADTypes.AutoZygote()`, specifying which library to use for
automatic differentiation.
## Example
@@ -64,7 +64,7 @@ result = muse(prob, (θ=0,))
function SossMuseProblem(
model::Soss.ConditionalModel;
params = leaf_params(model),
autodiff = ForwardDiffBackend()
autodiff = ADTypes.AutoForwardDiff()
)
x = model.obs
!isempty(x) || error("Model must be conditioned on observed data.")
@@ -110,16 +110,20 @@ end

function ∇θ_logLike(prob::SossMuseProblem, x, z::AbstractVector, θ::ComponentVector, ::UnTransformedθ)
like = prob.model | (;x..., TV.transform(prob.xform_z, z)...)
first(AD.gradient(prob.autodiff, θ -> Soss.logdensityof(like, _namedtuple(θ)), θ))
DI.gradient(θ -> Soss.logdensityof(like, _namedtuple(θ)), prob.autodiff, θ)
end
function ∇θ_logLike(prob::SossMuseProblem, x, z::AbstractVector, θ::AbstractVector, ::Transformedθ)
like = prob.model | (;x..., TV.transform(prob.xform_z, z)...)
first(AD.gradient(prob.autodiff, θ -> Soss.logdensityof(like, _namedtuple(inv_transform_θ(prob, θ))), θ))
DI.gradient(prob.autodiff, θ) do θ
Soss.logdensityof(like, _namedtuple(inv_transform_θ(prob, θ)))
end
end


function logLike_and_∇z_logLike(prob::SossMuseProblem, x, z, θ)
first.(AD.value_and_gradient(prob.autodiff, z -> Soss.logdensityof(prob.model | (;x..., _namedtuple(θ)...), TV.transform(prob.xform_z, z)), z))
DI.value_and_gradient(prob.autodiff, z) do z
Soss.logdensityof(prob.model | (;x..., _namedtuple(θ)...), TV.transform(prob.xform_z, z))
end
end

function sample_x_z(prob::SossMuseProblem, rng::AbstractRNG, θ)
@@ -150,4 +154,4 @@ function get_J!(result::MuseResult, model::Soss.ConditionalModel, θ₀ = result
end
function get_H!(result::MuseResult, model::Soss.ConditionalModel, θ₀ = result.θ; kwargs...)
get_H!(result, SossMuseProblem(model, params=_params_from_θ₀(θ₀)), θ₀; kwargs...)
end
end
16 changes: 9 additions & 7 deletions src/turing.jl
@@ -21,7 +21,7 @@ function DynPPL.maybe_invlink_before_eval!!(vi::DynPPL.SimpleVarInfo{NT,T,<:Part
end


struct TuringMuseProblem{A<:AD.AbstractBackend, M<:Turing.Model} <: AbstractMuseProblem
struct TuringMuseProblem{A<:ADTypes.AbstractADType, M<:Turing.Model} <: AbstractMuseProblem

autodiff :: A
model :: M
@@ -51,8 +51,8 @@ as a list of symbols. All other non-conditioned and non-`params`
variables will be considered the latent space.
The `autodiff` parameter should be either
`MuseInference.ForwardDiffBackend()` or
`MuseInference.ZygoteBackend()`, specifying which library to use for
`ADTypes.AutoForwardDiff()` or
`ADTypes.AutoZygote()`, specifying which library to use for
automatic differentiation. The default uses whatever the global
`Turing.ADBACKEND` is currently set to.
@@ -118,14 +118,14 @@ function TuringMuseProblem(
# set backend based on Turing's by default
if autodiff == nothing
if Turing.ADBACKEND[] == :zygote
autodiff = AD.ZygoteBackend()
autodiff = ADTypes.AutoZygote()
elseif Turing.ADBACKEND[] == :forwarddiff
autodiff = AD.ForwardDiffBackend()
autodiff = ADTypes.AutoForwardDiff()
else
error("Unsupposed backend from Turing: $(Turing.ADBACKEND)")
end
end
if (Threads.nthreads() > 1) && hasmethod(AD.ZygoteBackend,Tuple{}) && (autodiff isa typeof(AD.ZygoteBackend()))
if (Threads.nthreads() > 1) && (autodiff isa typeof(ADTypes.AutoZygote()))
error("Turing doesn't support using the Zygote backend when Threads.nthreads()>1. Use a different backend or a single-thread.")
end

@@ -202,7 +202,9 @@ function logPriorθ(prob::TuringMuseProblem, θ, θ_space)
end

function ∇θ_logLike(prob::TuringMuseProblem, x, z, θ, θ_space)
first(AD.gradient(prob.autodiff, θ -> logLike(prob, x, z, θ, θ_space), θ))
DI.gradient(prob.autodiff, θ) do θ
logLike(prob, x, z, θ, θ_space)
end
end

function ẑ_at_θ(prob::TuringMuseProblem, x, z₀, θ; ∇z_logLike_atol)
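One nicety of the new setup, visible in the `TuringMuseProblem` constructor above: `ADTypes` backends are plain concrete structs, so both the default-backend selection and the multithreading `isa` check stay simple. A hypothetical standalone version of that selection (the `:zygote`/`:forwarddiff` symbols mirror Turing's global `ADBACKEND` values as used in the diff):

```julia
using ADTypes

function backend_from_symbol(adsym::Symbol)
    adsym == :zygote      ? ADTypes.AutoZygote() :
    adsym == :forwarddiff ? ADTypes.AutoForwardDiff() :
    error("Unsupported backend from Turing: $adsym")
end

backend_from_symbol(:zygote) isa ADTypes.AutoZygote  # true; AutoZygote is a concrete type
```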
7 changes: 1 addition & 6 deletions src/util.jl
@@ -65,11 +65,6 @@ function Random.randn!(rng::AbstractRNG, A::Array{<:ForwardDiff.Dual})
A .= randn!(rng, ForwardDiff.value.(A))
end

# type-piracy bc these make code much clearer to read. could be removed if
# https://github.com/JuliaDiff/AbstractDifferentiation.jl/pull/62 is merged
AD.gradient(f, args...; backend::AD.AbstractBackend) = AD.gradient(backend, f, args...)
AD.jacobian(f, args...; backend::AD.AbstractBackend) = AD.jacobian(backend, f, args...)

# worker pool which just falls back to map
struct LocalWorkerPool <: AbstractWorkerPool end
Distributed.pmap(f, ::LocalWorkerPool, args...) = map(f, args...)
@@ -95,4 +90,4 @@ versionof(pkg::Module) = Pkg.dependencies()[Base.PkgId(pkg).uuid].version

# allow using InverseMap as an IterativeSolvers preconditioner
LinearAlgebra.ldiv!(dst::AbstractVector, A::InverseMap, src::AbstractVector) = mul!(dst, A.A, src)
LinearAlgebra.ldiv!(A::InverseMap, vec::AbstractVector) = copyto!(vec, mul!(A.A, vec))
LinearAlgebra.ldiv!(A::InverseMap, vec::AbstractVector) = copyto!(vec, mul!(A.A, vec))
6 changes: 3 additions & 3 deletions test/Project.toml
@@ -1,5 +1,5 @@
[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -11,8 +11,8 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MeasureTheory = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Soss = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -23,4 +23,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Soss = "0.21.2"
Turing = "0.28"
Turing = "0.28"