Skip to content

Commit

Permalink
Soss: account for transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed May 6, 2022
1 parent 923cade commit 563db8d
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 20 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand Down
29 changes: 17 additions & 12 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ function ẑ_at_θ(prob::AbstractMuseProblem, x, z₀, θ; ∇z_logLike_atol)
end

function _check_optim_soln(soln)
Optim.converged(soln) || warn("MAP solution failed, result could be erroneous. Try tweaking `θ₀` or `∇z_logLike_atol` arguments to `muse` or fixing model.")
isfinite(soln.minimum) || error("MAP solution failed with logjoint(MAP)=$(soln.minimum).")
Optim.converged(soln) || @warn("MAP solution failed, result could be erroneous. Try tweaking `θ₀` or `∇z_logLike_atol` arguments to `muse` or fixing model.")
isfinite(soln.minimum) || @error("MAP solution failed with logjoint(MAP)=$(soln.minimum).")
end


Expand All @@ -174,35 +174,40 @@ end
θ;
fdm = central_fdm(3, 1),
atol = 1e-3,
rng = Random.default_rng()
rng = Random.default_rng(),
has_volume_factor = true
)
Checks the self-consistency of a defined problem at a given `θ`, e.g.
that `inv_transform_θ(prob, transform_θ(prob, θ)) ≈ θ`, etc... Mostly
useful as a diagonostic when implementing a new `AbstractMuseProblem`.
check that `inv_transform_θ(prob, transform_θ(prob, θ)) ≈ θ`, etc...
This is mostly useful as a diagonostic when implementing a new
`AbstractMuseProblem`.
A random `x` and `z` are sampled from `rng`. Finite differences are
computed using `fdm` and `atol` set the tolerance for `≈`.
computed using `fdm` and `atol` set the tolerance for `≈`.
`has_volume_factor` determines if the transformation includes the
logdet jacobian in the likelihood.
"""
function check_self_consistency(
prob,
θ;
fdm = central_fdm(3, 1),
atol = 1e-3,
rng = Random.default_rng()
rng = Random.default_rng(),
has_volume_factor = true
)

θ = standardizeθ(prob, θ)
x, z = sample_x_z(prob, rng, θ)
# volume factor which is added by transformations. dont assume the
# transformation is AD-able (eg it isnt for Turing)
J(θ) = FiniteDifferences.jacobian(fdm, θ -> transform_θ(prob, θ), θ)[1]
V(θ) = logdet(J(θ))
∇θ_V(θ) = FiniteDifferences.grad(fdm, V, θ)[1]
J(θ) = has_volume_factor ? FiniteDifferences.jacobian(fdm, θ -> transform_θ(prob, θ), θ)[1] : 1
V(θ) = has_volume_factor ? logdet(J(θ)) : 0
∇θ_V(θ) = has_volume_factor ? FiniteDifferences.grad(fdm, V, θ)[1] : 0
@testset "Self-consistency" begin
@test inv_transform_θ(prob, transform_θ(prob, θ)) θ atol=atol
@test logPriorθ(prob, θ, UnTransformedθ()) logPriorθ(prob, transform_θ(prob, θ), Transformedθ()) + V(θ) atol=atol
@test ∇θ_logLike(prob, x, z, θ, UnTransformedθ()) J(θ)' * ∇θ_logLike(prob, x, z, transform_θ(prob, θ), Transformedθ()) + ∇θ_V(θ) atol=atol
@test logPriorθ(prob, θ, UnTransformedθ()) logPriorθ(prob, transform_θ(prob, θ), Transformedθ()) .+ V(θ) atol=atol
@test ∇θ_logLike(prob, x, z, θ, UnTransformedθ()) J(θ)' * ∇θ_logLike(prob, x, z, transform_θ(prob, θ), Transformedθ()) .+ ∇θ_V(θ) atol=atol
end
end

33 changes: 25 additions & 8 deletions src/soss.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@


using .Soss
import .Soss.TransformVariables as TV

export SossMuseProblem

struct SossMuseProblem{A<:AD.AbstractBackend, M<:Soss.AbstractModel, MP<:Soss.AbstractModel} <: AbstractMuseProblem
autodiff :: A
model :: M
model_for_prior :: MP
xform_z
xform_θ
x
observed_vars
latent_vars
Expand All @@ -20,14 +23,19 @@ function SossMuseProblem(
autodiff = ForwardDiffBackend()
)
x = model.obs
sim = rand(model)
observed_vars = keys(x)
hyper_vars = params
latent_vars = keys(delete(rand(model), (observed_vars..., hyper_vars...)))
latent_vars = keys(delete(sim, (observed_vars..., hyper_vars...)))
model_for_prior = likelihood(Model(model), hyper_vars...)(argvals(model))
xform_z = xform(model | select(sim, hyper_vars))
xform_θ = xform(model | select(sim, latent_vars))
SossMuseProblem(
autodiff,
model,
model_for_prior,
xform_z,
xform_θ,
x,
observed_vars,
latent_vars,
Expand All @@ -36,29 +44,38 @@ function SossMuseProblem(
end

function transform_θ(prob::SossMuseProblem, θ)
θ # TODO
TV.inverse(prob.xform_θ, _namedtuple(θ))
end

function inv_transform_θ(prob::SossMuseProblem, θ)
θ # TODO
ComponentVector(TV.transform(prob.xform_θ, θ))
end

function logPriorθ(prob::SossMuseProblem, θ, θ_space)
function logPriorθ(prob::SossMuseProblem, θ::ComponentVector, ::UnTransformedθ)
logdensity(prob.model_for_prior(_namedtuple(θ)))
end
function logPriorθ(prob::SossMuseProblem, θ::AbstractVector, ::Transformedθ)
logPriorθ(prob, inv_transform_θ(prob, θ), UnTransformedθ())
end

function ∇θ_logLike(prob::SossMuseProblem, x, z, θ, θ_space)
first(AD.gradient(prob.autodiff, θ -> logdensity(prob.model | (;_namedtuple(x)..., _namedtuple(z)...), _namedtuple(θ)), θ))
function ∇θ_logLike(prob::SossMuseProblem, x, z::AbstractVector, θ::ComponentVector, ::UnTransformedθ)
like = prob.model | (;_namedtuple(x)..., TV.transform(prob.xform_z, z)...)
first(AD.gradient(prob.autodiff, θ -> logdensity(like, _namedtuple(θ)), θ))
end
function ∇θ_logLike(prob::SossMuseProblem, x, z::AbstractVector, θ::AbstractVector, ::Transformedθ)
like = prob.model | (;_namedtuple(x)..., TV.transform(prob.xform_z, z)...)
first(AD.gradient(prob.autodiff, θ -> logdensity(like, _namedtuple(inv_transform_θ(prob, θ))), θ))
end


function logLike_and_∇z_logLike(prob::SossMuseProblem, x, z, θ)
first.(AD.value_and_gradient(prob.autodiff, z -> logdensity(prob.model | (;_namedtuple(x)..., _namedtuple(θ)...), _namedtuple(z)), z))
first.(AD.value_and_gradient(prob.autodiff, z -> logdensity(prob.model | (;_namedtuple(x)..., _namedtuple(θ)...), TV.transform(prob.xform_z, z)), z))
end

function sample_x_z(prob::SossMuseProblem, rng::AbstractRNG, θ)
sim = predict(rng, prob.model, _namedtuple(θ))
x = ComponentVector(select(sim, prob.observed_vars))
z = ComponentVector(select(sim, prob.latent_vars))
z = TV.inverse(prob.xform_z, select(sim, prob.latent_vars))
(;x, z)
end

Expand Down

0 comments on commit 563db8d

Please sign in to comment.