From 8c3aa44d9e3c9d1fec96b8db97672fa8c2b4ecb4 Mon Sep 17 00:00:00 2001 From: Thomas Wutzler Date: Mon, 30 Sep 2024 10:30:49 +0200 Subject: [PATCH] Pointpriors (#663) * implement pointwise_logpriors * implement varwise_logpriors * remove pointwise_logpriors * revert dot_assume to not explicitly resolve components of sum * docstring varwise_logpriores use loop for prior in example Unfortunately cannot make it a jldoctest, because relies on Turing for sampling * integrate pointwise_loglikelihoods and varwise_logpriors by pointwise_densities * record single prior components by forwarding dot_tilde_assume to tilde_assume * forward dot_tilde_assume to tilde_assume for Multivariate * avoid recording prior components on leaf-prior-context and avoid recording likelihoods when invoked with leaf-Likelihood context * undeprecate pointwise_loglikelihoods and implement pointwise_prior_logdensities mostly taken from #669 * drop vi instead of re-compute vi bgctw first forwared dot_tilde_assume to get a correct vi and then recomputed it for recording component prior densities. Replaced this by the Hack of torfjelde that completely drops vi and recombines the value, so that assume is called only once for each varName, * include docstrings of pointwise_logdensities pointwise_prior_logdensities int api.md docu * Update src/pointwise_logdensities.jl remove commented code Co-authored-by: Tor Erlend Fjelde * Update src/pointwise_logdensities.jl remove commented code Co-authored-by: Tor Erlend Fjelde * Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde * Update test/pointwise_logdensities.jl remove unused code Co-authored-by: Tor Erlend Fjelde * Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde * Update test/pointwise_logdensities.jl rename m to model Co-authored-by: Tor Erlend Fjelde * Update src/test_utils.jl remove old code Co-authored-by: Tor Erlend Fjelde * rename m to model * JuliaFormatter * Update test/runtests.jl remove interactive code Co-authored-by: Tor Erlend Fjelde * remove demo_dot_assume_matrix_dot_observe_matrix2 testcase testing higher dimensions better left for other PR * ignore local interactive development code * ignore temporary directory holding local interactive development code * Apply suggestions from code review: clean up comments and Imports Co-authored-by: Tor Erlend Fjelde * Apply suggestions from code review: change test of applying to chains on already used model Co-authored-by: Tor Erlend Fjelde * fix test on names in likelihood components to work with literal models * try to fix testset pointwise_logdensities chain * Update test/pointwise_logdensities.jl * Update .gitignore * Formtating * Fixed tests * Updated docs for `pointwise_logdensities` + made it a doctest not dependent on Turing.jl * Bump patch version * Remove blank line from `@model` in doctest to see if that fixes the parsing issues * Added doctest filter to handle the `;;]` at the end of lines for matrices --------- Co-authored-by: Tor Erlend Fjelde Co-authored-by: Tor Erlend Fjelde --- Project.toml | 2 +- docs/src/api.md | 6 +- src/DynamicPPL.jl | 4 +- src/loglikelihoods.jl | 257 --------------------- src/pointwise_logdensities.jl | 395 +++++++++++++++++++++++++++++++++ src/test_utils.jl | 40 ++++ test/contexts.jl | 4 +- test/loglikelihoods.jl | 24 -- test/pointwise_logdensities.jl | 101 +++++++++ test/runtests.jl | 7 +- 10 files changed, 553 insertions(+), 287 deletions(-) delete mode 100644 src/loglikelihoods.jl create mode 100644 src/pointwise_logdensities.jl delete mode 100644 test/loglikelihoods.jl create mode 100644 test/pointwise_logdensities.jl diff --git a/Project.toml b/Project.toml index 700f040c7..95ce8cde2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.29.1" +version = "0.29.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index 97c48316e..156b51e03 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -124,10 +124,14 @@ Return values of the model function for a collection of samples can be obtained generated_quantities ``` -For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). +For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using +[`pointwise_prior_logdensities`](@ref) or both, i.e. all variables, using +[`pointwise_logdensities`](@ref). ```@docs +pointwise_logdensities pointwise_loglikelihoods +pointwise_prior_logdensities ``` For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index eb027b45b..777c770d4 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -115,6 +115,8 @@ export AbstractVarInfo, # Convenience functions logprior, logjoint, + pointwise_prior_logdensities, + pointwise_logdensities, pointwise_loglikelihoods, condition, decondition, @@ -181,7 +183,7 @@ include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") -include("loglikelihoods.jl") +include("pointwise_logdensities.jl") include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl deleted file mode 100644 index 227a70889..000000000 --- a/src/loglikelihoods.jl +++ /dev/null @@ -1,257 +0,0 @@ -# Context version -struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext - loglikelihoods::A - context::Ctx -end - -function PointwiseLikelihoodContext( - likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=LikelihoodContext(), -) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( - likelihoods, context - ) -end - -NodeTrait(::PointwiseLikelihoodContext) = IsParent() -childcontext(context::PointwiseLikelihoodContext) = context.context -function setchildcontext(context::PointwiseLikelihoodContext, child) - return PointwiseLikelihoodContext(context.loglikelihoods, child) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.loglikelihoods - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.loglikelihoods[vn] = logp -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.loglikelihoods - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.loglikelihoods[string(vn)] = logp -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.loglikelihoods - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.loglikelihoods[vn] = logp -end - -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!!(context.context, right, left, vi) -end -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. - logp, vi = tilde_observe(context.context, right, left, vi) - - # Track loglikelihood value. - push!(context, vn, logp) - - return left, acclogp!!(vi, logp) -end - -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) - # Defer literal `observe` to child-context. - return dot_tilde_observe!!(context.context, right, left, vi) -end -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. - - # We want to treat `.~` as a collection of independent observations, - # hence we need the `logp` for each of them. Broadcasting the univariate - # `tilde_obseve` does exactly this. - logps = _pointwise_tilde_observe(context.context, right, left, vi) - - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. - _, _, vns = unwrap_right_left_vns(right, left, vn) - for (vn, logp) in zip(vns, logps) - # Track loglikelihood value. - push!(context, vn, logp) - end - - return left, acclogp!!(vi, sum(logps)) -end - -# FIXME: This is really not a good approach since it needs to stay in sync with -# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. -function _pointwise_tilde_observe(context, right, left, vi) - # We need to drop the `vi` returned. - return broadcast(right, left) do r, l - return first(tilde_observe(context, r, l, vi)) - end -end - -function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo -) - # We need to drop the `vi` returned. - return map(eachcol(left)) do l - return first(tilde_observe(context, right, l, vi)) - end -end - -""" - pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String) - -Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -with keys corresponding to symbols of the observations, and values being matrices -of shape `(num_chains, num_samples)`. - -`keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. - -# Notes -Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -*observation*) statements can be implemented in three ways: -1. using a `for` loop: -```julia -for i in eachindex(y) - y[i] ~ Normal(μ, σ) -end -``` -2. using `.~`: -```julia -y .~ Normal(μ, σ) -``` -3. using `MvNormal`: -```julia -y ~ MvNormal(fill(μ, n), σ^2 * I) -``` - -In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -while in (3) `y` will be treated as a _single_ n-dimensional observation. - -This is important to keep in mind, in particular if the computation is used -for downstream computations. - -# Examples -## From chain -```julia-repl -julia> using DynamicPPL, Turing - -julia> @model function demo(xs, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - - y ~ Normal(m, √s) - end -demo (generic function with 1 method) - -julia> model = demo(randn(3), randn()); - -julia> chain = sample(model, MH(), 10); - -julia> pointwise_loglikelihoods(model, chain) -OrderedDict{String,Array{Float64,2}} with 4 entries: - "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -julia> pointwise_loglikelihoods(model, chain, String) -OrderedDict{String,Array{Float64,2}} with 4 entries: - "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -julia> pointwise_loglikelihoods(model, chain, VarName) -OrderedDict{VarName,Array{Float64,2}} with 4 entries: - xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -``` - -## Broadcasting -Note that `x .~ Dist()` will treat `x` as a collection of -_independent_ observations rather than as a single observation. - -```jldoctest; setup = :(using Distributions) -julia> @model function demo(x) - x .~ Normal() - end; - -julia> m = demo([1.0, ]); - -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first(ℓ[@varname(x[1])]) --1.4189385332046727 - -julia> m = demo([1.0; 1.0]); - -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -(-1.4189385332046727, -1.4189385332046727) -``` - -""" -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} - # Get the data by executing the model once - vi = VarInfo(model) - context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}()) - - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - for (sample_idx, chain_idx) in iters - # Update the values - setval!(vi, chain, sample_idx, chain_idx) - - # Execute model - model(vi, context) - end - - niters = size(chain, 1) - nchains = size(chain, 3) - loglikelihoods = OrderedDict( - varname => reshape(logliks, niters, nchains) for - (varname, logliks) in context.loglikelihoods - ) - return loglikelihoods -end - -function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}()) - model(varinfo, context) - return context.loglikelihoods -end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl new file mode 100644 index 000000000..47b969e6c --- /dev/null +++ b/src/pointwise_logdensities.jl @@ -0,0 +1,395 @@ +# Context version +struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext + logdensities::A + context::Ctx +end + +function PointwiseLogdensityContext( + likelihoods=OrderedDict{VarName,Vector{Float64}}(), + context::AbstractContext=DefaultContext(), +) + return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) +end + +NodeTrait(::PointwiseLogdensityContext) = IsParent() +childcontext(context::PointwiseLogdensityContext) = context.context +function setchildcontext(context::PointwiseLogdensityContext, child) + return PointwiseLogdensityContext(context.logdensities, child) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, + vn::VarName, + logp::Real, +) + return context.logdensities[vn] = logp +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, string(vn), Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, + vn::VarName, + logp::Real, +) + return context.logdensities[string(vn)] = logp +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, + vn::String, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, + vn::String, + logp::Real, +) + return context.logdensities[vn] = logp +end + +function _include_prior(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{PriorContext,DefaultContext} +end +function _include_likelihood(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +end + +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!!(context.context, right, left, vi) +end +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return tilde_observe!!(context.context, right, left, vn, vi) + end + + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp, vi = tilde_observe(context.context, right, left, vi) + + # Track loglikelihood value. + push!(context, vn, logp) + + return left, acclogp!!(vi, logp) +end + +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe!!(context.context, right, left, vi) +end +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return dot_tilde_observe!!(context.context, right, left, vn, vi) + end + + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + + # We want to treat `.~` as a collection of independent observations, + # hence we need the `logp` for each of them. Broadcasting the univariate + # `tilde_observe` does exactly this. + logps = _pointwise_tilde_observe(context.context, right, left, vi) + + # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. + _, _, vns = unwrap_right_left_vns(right, left, vn) + for (vn, logp) in zip(vns, logps) + # Track loglikelihood value. + push!(context, vn, logp) + end + + return left, acclogp!!(vi, sum(logps)) +end + +# FIXME: This is really not a good approach since it needs to stay in sync with +# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. +function _pointwise_tilde_observe(context, right, left, vi) + # We need to drop the `vi` returned. + return broadcast(right, left) do r, l + return first(tilde_observe(context, r, l, vi)) + end +end + +function _pointwise_tilde_observe( + context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo +) + # We need to drop the `vi` returned. + return map(eachcol(left)) do l + return first(tilde_observe(context, right, l, vi)) + end +end + +function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) + !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) + value, logp, vi = tilde_assume(context.context, right, vn, vi) + # Track loglikelihood value. + push!(context, vn, logp) + return value, acclogp!!(vi, logp) +end + +function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) + !_include_prior(context) && + return (dot_tilde_assume!!(context.context, right, left, vns, vi)) + value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) + # Track loglikelihood values. + for (vn, logp) in zip(vns, logps) + push!(context, vn, logp) + end + return value, acclogp!!(vi, sum(logps)) +end + +function _pointwise_tilde_assume(context, right, left, vns, vi) + # We need to drop the `vi` returned. + values_and_logps = broadcast(right, left, vns) do r, l, vn + # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated + # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, + # and b) even if the variables aren't stored in the vi correctly, we're not going to use + # this vi for anything downstream anyways, i.e. I don't see a case where this would matter + # for this particular use case. + val, logp, _ = tilde_assume(context, r, vn, vi) + return val, logp + end + return map(first, values_and_logps), map(last, values_and_logps) +end +function _pointwise_tilde_assume( + context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi +) + # We need to drop the `vi` returned. + values_and_logps = map(eachcol(left), vns) do l, vn + val, logp, _ = tilde_assume(context, right, vn, vi) + return val, logp + end + # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. + # But this also means that we need to first flatten the entire `values` component before recombining. + values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) + return values, map(last, values_and_logps) +end + +""" + pointwise_logdensities(model::Model, chain::Chains, keytype = String) + +Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` +with keys corresponding to symbols of the variables, and values being matrices +of shape `(num_chains, num_samples)`. + +`keytype` specifies what the type of the keys used in the returned `OrderedDict` are. +Currently, only `String` and `VarName` are supported. + +# Notes +Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` +both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an +*observation*) statements can be implemented in three ways: +1. using a `for` loop: +```julia +for i in eachindex(y) + y[i] ~ Normal(μ, σ) +end +``` +2. using `.~`: +```julia +y .~ Normal(μ, σ) +``` +3. using `MvNormal`: +```julia +y ~ MvNormal(fill(μ, n), σ^2 * I) +``` + +In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, +while in (3) `y` will be treated as a _single_ n-dimensional observation. + +This is important to keep in mind, in particular if the computation is used +for downstream computations. + +# Examples +## From chain +```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) +julia> using MCMCChains + +julia> @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end +demo (generic function with 2 methods) + +julia> # Example observations. + model = demo([1.0, 2.0, 3.0], [4.0]); + +julia> # A chain with 3 iterations. + chain = Chains( + reshape(1.:6., 3, 2), + [:s, :m] + ); + +julia> pointwise_logdensities(model, chain) +OrderedDict{String, Matrix{Float64}} with 6 entries: + "s" => [-0.802775; -1.38222; -2.09861;;] + "m" => [-8.91894; -7.51551; -7.46824;;] + "xs[1]" => [-5.41894; -5.26551; -5.63491;;] + "xs[2]" => [-2.91894; -3.51551; -4.13491;;] + "xs[3]" => [-1.41894; -2.26551; -2.96824;;] + "y" => [-0.918939; -1.51551; -2.13491;;] + +julia> pointwise_logdensities(model, chain, String) +OrderedDict{String, Matrix{Float64}} with 6 entries: + "s" => [-0.802775; -1.38222; -2.09861;;] + "m" => [-8.91894; -7.51551; -7.46824;;] + "xs[1]" => [-5.41894; -5.26551; -5.63491;;] + "xs[2]" => [-2.91894; -3.51551; -4.13491;;] + "xs[3]" => [-1.41894; -2.26551; -2.96824;;] + "y" => [-0.918939; -1.51551; -2.13491;;] + +julia> pointwise_logdensities(model, chain, VarName) +OrderedDict{VarName, Matrix{Float64}} with 6 entries: + s => [-0.802775; -1.38222; -2.09861;;] + m => [-8.91894; -7.51551; -7.46824;;] + xs[1] => [-5.41894; -5.26551; -5.63491;;] + xs[2] => [-2.91894; -3.51551; -4.13491;;] + xs[3] => [-1.41894; -2.26551; -2.96824;;] + y => [-0.918939; -1.51551; -2.13491;;] +``` + +## Broadcasting +Note that `x .~ Dist()` will treat `x` as a collection of +_independent_ observations rather than as a single observation. + +```jldoctest; setup = :(using Distributions) +julia> @model function demo(x) + x .~ Normal() + end; + +julia> m = demo([1.0, ]); + +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +-1.4189385332046727 + +julia> m = demo([1.0; 1.0]); + +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +(-1.4189385332046727, -1.4189385332046727) +``` + +""" +function pointwise_logdensities( + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() +) where {T} + # Get the data by executing the model once + vi = VarInfo(model) + point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) + + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + for (sample_idx, chain_idx) in iters + # Update the values + setval!(vi, chain, sample_idx, chain_idx) + + # Execute model + model(vi, point_context) + end + + niters = size(chain, 1) + nchains = size(chain, 3) + logdensities = OrderedDict( + varname => reshape(logliks, niters, nchains) for + (varname, logliks) in point_context.logdensities + ) + return logdensities +end + +function pointwise_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() +) + point_context = PointwiseLogdensityContext( + OrderedDict{VarName,Vector{Float64}}(), context + ) + model(varinfo, point_context) + return point_context.logdensities +end + +""" + pointwise_loglikelihoods(model, chain[, keytype, context]) + +Compute the pointwise log-likelihoods of the model given the chain. +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the likelihood terms. +See also: [`pointwise_logdensities`](@ref). +""" +function pointwise_loglikelihoods( + model::Model, + chain, + keytype::Type{T}=String, + context::AbstractContext=LikelihoodContext(), +) where {T} + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, chain, T, context) +end + +function pointwise_loglikelihoods( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() +) + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, varinfo, context) +end + +""" + pointwise_prior_logdensities(model, chain[, keytype, context]) + +Compute the pointwise log-prior-densities of the model given the chain. +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the prior terms. +See also: [`pointwise_logdensities`](@ref). +""" +function pointwise_prior_logdensities( + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() +) where {T} + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, chain, T, context) +end + +function pointwise_prior_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() +) + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, varinfo, context) +end diff --git a/src/test_utils.jl b/src/test_utils.jl index 6f7481c40..8489f2684 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1042,4 +1042,44 @@ function test_context_interface(context) end end +""" +Context that multiplies each log-prior by mod +used to test whether varwise_logpriors respects child-context. +""" +struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext + mod::T + context::Ctx +end +function TestLogModifyingChildContext( + mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext() +) + return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) +end + +DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context +function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) + return TestLogModifyingChildContext(context.mod, child) +end +function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) + value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) + return value, logp * context.mod, vi +end +function DynamicPPL.dot_tilde_assume( + context::TestLogModifyingChildContext, right, left, vn, vi +) + value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) + return value, logp * context.mod, vi +end +function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) + logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) + return logp * context.mod, vi +end +function DynamicPPL.dot_tilde_observe( + context::TestLogModifyingChildContext, right, left, vi +) + logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) + return logp * context.mod, vi +end + end diff --git a/test/contexts.jl b/test/contexts.jl index 11e2c99b7..4ec9ff945 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,7 +8,7 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext, + PointwiseLogdensityContext, contextual_isassumption, ConditionContext, hasconditioned, @@ -67,7 +67,7 @@ end SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), - PointwiseLikelihoodContext(), + PointwiseLogdensityContext(), ConditionContext((x=1.0,)), ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl deleted file mode 100644 index 1075ce333..000000000 --- a/test/loglikelihoods.jl +++ /dev/null @@ -1,24 +0,0 @@ -@testset "loglikelihoods.jl" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - - # Instantiate a `VarInfo` with the example values. - vi = VarInfo(m) - for vn in DynamicPPL.TestUtils.varnames(m) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - - # Compute the pointwise loglikelihoods. - lls = pointwise_loglikelihoods(m, vi) - - if isempty(lls) - # One of the models with literal observations, so we just skip. - continue - end - - loglikelihood = sum(sum, values(lls)) - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - - @test loglikelihood ≈ loglikelihood_true - end -end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl new file mode 100644 index 000000000..93b7c59be --- /dev/null +++ b/test/pointwise_logdensities.jl @@ -0,0 +1,101 @@ +@testset "logdensities_likelihoods.jl" begin + mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) + mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + example_values = DynamicPPL.TestUtils.rand_prior_true(model) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(model) + for vn in DynamicPPL.TestUtils.varnames(model) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true( + model, example_values... + ) + logp_true = logprior(model, vi) + + # Compute the pointwise loglikelihoods. + lls = pointwise_loglikelihoods(model, vi) + if isempty(lls) + # One of the models with literal observations, so we'll set this to 0 for subsequent comparisons. + loglikelihood_true = 0.0 + else + @test [:x] == unique(DynamicPPL.getsym.(keys(lls))) + loglikelihood_sum = sum(sum, values(lls)) + @test loglikelihood_sum ≈ loglikelihood_true + end + + # Compute the pointwise logdensities of the priors. + lps_prior = pointwise_prior_logdensities(model, vi) + @test :x ∉ DynamicPPL.getsym.(keys(lps_prior)) + logp = sum(sum, values(lps_prior)) + @test logp ≈ logp_true + + # Compute both likelihood and logdensity of prior + # using the default DefaultContex + lps = pointwise_logdensities(model, vi) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) + + # Test that modifications of Setup are picked up + lps = pointwise_logdensities(model, vi, mod_ctx2) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) * 1.2 * 1.4 + end +end + +@testset "pointwise_logdensities chain" begin + # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, + # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just + # to ensure that we don't accidentally break the the version on `Chains`. + model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() + # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced + # an impl of this for containers. + # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. + vns = DynamicPPL.TestUtils.varnames(model) + # Get some random `NamedTuple` samples from the prior. + num_iters = 3 + vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] + # Concatenate the vector representations and create a `Chains` from it. + vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) + chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + + # Compute the different pointwise logdensities. + logjoints_pointwise = pointwise_logdensities(model, chain) + logpriors_pointwise = pointwise_prior_logdensities(model, chain) + loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) + + # Check that they contain the correct variables. + @test all(string(vn) in keys(logjoints_pointwise) for vn in vns) + @test all(string(vn) in keys(logpriors_pointwise) for vn in vns) + @test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise)) + @test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns) + @test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise)) + + # Get the sum of the logjoints for each of the iterations. + logjoints = [ + sum(logjoints_pointwise[vn][idx] for vn in keys(logjoints_pointwise)) for + idx in 1:num_iters + ] + logpriors = [ + sum(logpriors_pointwise[vn][idx] for vn in keys(logpriors_pointwise)) for + idx in 1:num_iters + ] + loglikelihoods = [ + sum(loglikelihoods_pointwise[vn][idx] for vn in keys(loglikelihoods_pointwise)) for + idx in 1:num_iters + ] + + for (val, logjoint, logprior, loglikelihood) in + zip(vals, logjoints, logpriors, loglikelihoods) + # Compare true logjoint with the one obtained from `pointwise_logdensities`. + logjoint_true = DynamicPPL.TestUtils.logjoint_true(model, val...) + logprior_true = DynamicPPL.TestUtils.logprior_true(model, val...) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, val...) + + @test logjoint ≈ logjoint_true + @test logprior ≈ logprior_true + @test loglikelihood ≈ loglikelihood_true + end +end diff --git a/test/runtests.jl b/test/runtests.jl index aa0883708..b9a1d92bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,8 @@ using Pkg using Random using Serialization using Test +using Distributions +using LinearAlgebra # Diagonal using DynamicPPL: getargs_dottilde, getargs_tilde, Selector @@ -53,7 +55,7 @@ include("test_util.jl") include("serialization.jl") - include("loglikelihoods.jl") + include("pointwise_logdensities.jl") include("lkj.jl") end @@ -93,6 +95,9 @@ include("test_util.jl") # Errors from macros sometimes result in `LoadError: LoadError:` # rather than `LoadError:`, depending on Julia version. r"ERROR: (LoadError:\s)+", + # Older versions do not have `;;]` but instead just `]` at end of the line + # => need to treat `;;]` and `]` as the same, i.e. ignore them if at the end of a line + r"(;;){0,1}\]$"m, ] doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) end