diff --git a/Project.toml b/Project.toml index 700f040c7..95ce8cde2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.29.1" +version = "0.29.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index 97c48316e..156b51e03 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -124,10 +124,14 @@ Return values of the model function for a collection of samples can be obtained generated_quantities ``` -For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). +For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, one can compute the pointwise log-densities of the priors with +[`pointwise_prior_logdensities`](@ref), or of all variables (priors and likelihoods together) with +[`pointwise_logdensities`](@ref). ```@docs +pointwise_logdensities pointwise_loglikelihoods +pointwise_prior_logdensities ``` For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). 
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index eb027b45b..777c770d4 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -115,6 +115,8 @@ export AbstractVarInfo, # Convenience functions logprior, logjoint, + pointwise_prior_logdensities, + pointwise_logdensities, pointwise_loglikelihoods, condition, decondition, @@ -181,7 +183,7 @@ include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") -include("loglikelihoods.jl") +include("pointwise_logdensities.jl") include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl deleted file mode 100644 index 227a70889..000000000 --- a/src/loglikelihoods.jl +++ /dev/null @@ -1,257 +0,0 @@ -# Context version -struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext - loglikelihoods::A - context::Ctx -end - -function PointwiseLikelihoodContext( - likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=LikelihoodContext(), -) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( - likelihoods, context - ) -end - -NodeTrait(::PointwiseLikelihoodContext) = IsParent() -childcontext(context::PointwiseLikelihoodContext) = context.context -function setchildcontext(context::PointwiseLikelihoodContext, child) - return PointwiseLikelihoodContext(context.loglikelihoods, child) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.loglikelihoods - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.loglikelihoods[vn] = logp -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.loglikelihoods - 
ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.loglikelihoods[string(vn)] = logp -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.loglikelihoods - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) -end - -function Base.push!( - context::PointwiseLikelihoodContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.loglikelihoods[vn] = logp -end - -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!!(context.context, right, left, vi) -end -function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. - logp, vi = tilde_observe(context.context, right, left, vi) - - # Track loglikelihood value. - push!(context, vn, logp) - - return left, acclogp!!(vi, logp) -end - -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) - # Defer literal `observe` to child-context. - return dot_tilde_observe!!(context.context, right, left, vi) -end -function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. - - # We want to treat `.~` as a collection of independent observations, - # hence we need the `logp` for each of them. Broadcasting the univariate - # `tilde_obseve` does exactly this. - logps = _pointwise_tilde_observe(context.context, right, left, vi) - - # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. 
- _, _, vns = unwrap_right_left_vns(right, left, vn) - for (vn, logp) in zip(vns, logps) - # Track loglikelihood value. - push!(context, vn, logp) - end - - return left, acclogp!!(vi, sum(logps)) -end - -# FIXME: This is really not a good approach since it needs to stay in sync with -# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. -function _pointwise_tilde_observe(context, right, left, vi) - # We need to drop the `vi` returned. - return broadcast(right, left) do r, l - return first(tilde_observe(context, r, l, vi)) - end -end - -function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo -) - # We need to drop the `vi` returned. - return map(eachcol(left)) do l - return first(tilde_observe(context, right, l, vi)) - end -end - -""" - pointwise_loglikelihoods(model::Model, chain::Chains, keytype = String) - -Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -with keys corresponding to symbols of the observations, and values being matrices -of shape `(num_chains, num_samples)`. - -`keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. - -# Notes -Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` -both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an -*observation*) statements can be implemented in three ways: -1. using a `for` loop: -```julia -for i in eachindex(y) - y[i] ~ Normal(μ, σ) -end -``` -2. using `.~`: -```julia -y .~ Normal(μ, σ) -``` -3. using `MvNormal`: -```julia -y ~ MvNormal(fill(μ, n), σ^2 * I) -``` - -In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, -while in (3) `y` will be treated as a _single_ n-dimensional observation. - -This is important to keep in mind, in particular if the computation is used -for downstream computations. 
- -# Examples -## From chain -```julia-repl -julia> using DynamicPPL, Turing - -julia> @model function demo(xs, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - - y ~ Normal(m, √s) - end -demo (generic function with 1 method) - -julia> model = demo(randn(3), randn()); - -julia> chain = sample(model, MH(), 10); - -julia> pointwise_loglikelihoods(model, chain) -OrderedDict{String,Array{Float64,2}} with 4 entries: - "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -julia> pointwise_loglikelihoods(model, chain, String) -OrderedDict{String,Array{Float64,2}} with 4 entries: - "xs[1]" => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - "xs[2]" => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - "xs[3]" => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - "y" => [-1.51265; -0.914129; … ; -1.5499; -1.5499] - -julia> pointwise_loglikelihoods(model, chain, VarName) -OrderedDict{VarName,Array{Float64,2}} with 4 entries: - xs[1] => [-1.42932; -2.68123; … ; -1.66333; -1.66333] - xs[2] => [-1.6724; -0.861339; … ; -1.62359; -1.62359] - xs[3] => [-1.42862; -2.67573; … ; -1.66251; -1.66251] - y => [-1.51265; -0.914129; … ; -1.5499; -1.5499] -``` - -## Broadcasting -Note that `x .~ Dist()` will treat `x` as a collection of -_independent_ observations rather than as a single observation. 
- -```jldoctest; setup = :(using Distributions) -julia> @model function demo(x) - x .~ Normal() - end; - -julia> m = demo([1.0, ]); - -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first(ℓ[@varname(x[1])]) --1.4189385332046727 - -julia> m = demo([1.0; 1.0]); - -julia> ℓ = pointwise_loglikelihoods(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) -(-1.4189385332046727, -1.4189385332046727) -``` - -""" -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} - # Get the data by executing the model once - vi = VarInfo(model) - context = PointwiseLikelihoodContext(OrderedDict{T,Vector{Float64}}()) - - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - for (sample_idx, chain_idx) in iters - # Update the values - setval!(vi, chain, sample_idx, chain_idx) - - # Execute model - model(vi, context) - end - - niters = size(chain, 1) - nchains = size(chain, 3) - loglikelihoods = OrderedDict( - varname => reshape(logliks, niters, nchains) for - (varname, logliks) in context.loglikelihoods - ) - return loglikelihoods -end - -function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - context = PointwiseLikelihoodContext(OrderedDict{VarName,Vector{Float64}}()) - model(varinfo, context) - return context.loglikelihoods -end diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl new file mode 100644 index 000000000..47b969e6c --- /dev/null +++ b/src/pointwise_logdensities.jl @@ -0,0 +1,395 @@ +# Context version +struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext + logdensities::A + context::Ctx +end + +function PointwiseLogdensityContext( + likelihoods=OrderedDict{VarName,Vector{Float64}}(), + context::AbstractContext=DefaultContext(), +) + return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) +end + +NodeTrait(::PointwiseLogdensityContext) = IsParent() +childcontext(context::PointwiseLogdensityContext) = context.context 
+function setchildcontext(context::PointwiseLogdensityContext, child) + return PointwiseLogdensityContext(context.logdensities, child) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, + vn::VarName, + logp::Real, +) + return context.logdensities[vn] = logp +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, string(vn), Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, + vn::VarName, + logp::Real, +) + return context.logdensities[string(vn)] = logp +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, + vn::String, + logp::Real, +) + lookup = context.logdensities + ℓ = get!(lookup, vn, Float64[]) + return push!(ℓ, logp) +end + +function Base.push!( + context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, + vn::String, + logp::Real, +) + return context.logdensities[vn] = logp +end + +function _include_prior(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{PriorContext,DefaultContext} +end +function _include_likelihood(context::PointwiseLogdensityContext) + return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +end + +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!!(context.context, right, left, vi) +end +function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. 
+ if !(_include_likelihood(context)) + return tilde_observe!!(context.context, right, left, vn, vi) + end + + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp, vi = tilde_observe(context.context, right, left, vi) + + # Track loglikelihood value. + push!(context, vn, logp) + + return left, acclogp!!(vi, logp) +end + +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe!!(context.context, right, left, vi) +end +function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) + # Completely defer to child context if we are not tracking likelihoods. + if !(_include_likelihood(context)) + return dot_tilde_observe!!(context.context, right, left, vn, vi) + end + + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + + # We want to treat `.~` as a collection of independent observations, + # hence we need the `logp` for each of them. Broadcasting the univariate + # `tilde_observe` does exactly this. + logps = _pointwise_tilde_observe(context.context, right, left, vi) + + # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. + _, _, vns = unwrap_right_left_vns(right, left, vn) + for (vn, logp) in zip(vns, logps) + # Track loglikelihood value. + push!(context, vn, logp) + end + + return left, acclogp!!(vi, sum(logps)) +end + +# FIXME: This is really not a good approach since it needs to stay in sync with +# the `dot_assume` implementations, but as things are _right now_ this is the best we can do. +function _pointwise_tilde_observe(context, right, left, vi) + # We need to drop the `vi` returned. 
+ return broadcast(right, left) do r, l + return first(tilde_observe(context, r, l, vi)) + end +end + +function _pointwise_tilde_observe( + context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo +) + # We need to drop the `vi` returned. + return map(eachcol(left)) do l + return first(tilde_observe(context, right, l, vi)) + end +end + +function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) + !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) + value, logp, vi = tilde_assume(context.context, right, vn, vi) + # Track log-prior value. + push!(context, vn, logp) + return value, acclogp!!(vi, logp) +end + +function dot_tilde_assume!!(context::PointwiseLogdensityContext, right, left, vns, vi) + !_include_prior(context) && + return (dot_tilde_assume!!(context.context, right, left, vns, vi)) + value, logps = _pointwise_tilde_assume(context, right, left, vns, vi) + # Track log-prior values. + for (vn, logp) in zip(vns, logps) + push!(context, vn, logp) + end + return value, acclogp!!(vi, sum(logps)) +end + +function _pointwise_tilde_assume(context, right, left, vns, vi) + # We need to drop the `vi` returned. + values_and_logps = broadcast(right, left, vns) do r, l, vn + # HACK(torfjelde): This drops the `vi` returned, which means the `vi` is not updated + # in case of immutable varinfos. But a) atm we're only using mutable varinfos for this, + # and b) even if the variables aren't stored in the vi correctly, we're not going to use + # this vi for anything downstream anyways, i.e. I don't see a case where this would matter + # for this particular use case. + val, logp, _ = tilde_assume(context, r, vn, vi) + return val, logp + end + return map(first, values_and_logps), map(last, values_and_logps) +end +function _pointwise_tilde_assume( + context, right::MultivariateDistribution, left::AbstractMatrix, vns, vi +) + # We need to drop the `vi` returned. 
+ values_and_logps = map(eachcol(left), vns) do l, vn + val, logp, _ = tilde_assume(context, right, vn, vi) + return val, logp + end + # HACK(torfjelde): Due to the way we handle `.~`, we should use `recombine` to stay consistent. + # But this also means that we need to first flatten the entire `values` component before recombining. + values = recombine(right, mapreduce(vec ∘ first, vcat, values_and_logps), length(vns)) + return values, map(last, values_and_logps) +end + +""" + pointwise_logdensities(model::Model, chain::Chains, keytype = String) + +Runs `model` on each sample in `chain` returning an `OrderedDict{String, Matrix{Float64}}` +with keys corresponding to symbols of the variables, and values being matrices +of shape `(num_samples, num_chains)`. + +`keytype` specifies the type of the keys used in the returned `OrderedDict`. +Currently, only `String` and `VarName` are supported. + +# Notes +Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` +both being `<:Real`. Then the *observe* (i.e. when the left-hand side is an +*observation*) statements can be implemented in three ways: +1. using a `for` loop: +```julia +for i in eachindex(y) + y[i] ~ Normal(μ, σ) +end +``` +2. using `.~`: +```julia +y .~ Normal(μ, σ) +``` +3. using `MvNormal`: +```julia +y ~ MvNormal(fill(μ, n), σ^2 * I) +``` + +In (1) and (2), `y` will be treated as a collection of `n` i.i.d. 1-dimensional variables, +while in (3) `y` will be treated as a _single_ n-dimensional observation. + +This is important to keep in mind, in particular if the computation is used +for downstream computations. + +# Examples +## From chain +```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) +julia> using MCMCChains + +julia> @model function demo(xs, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + y ~ Normal(m, √s) + end +demo (generic function with 2 methods) + +julia> # Example observations. 
+ model = demo([1.0, 2.0, 3.0], [4.0]); + +julia> # A chain with 3 iterations. + chain = Chains( + reshape(1.:6., 3, 2), + [:s, :m] + ); + +julia> pointwise_logdensities(model, chain) +OrderedDict{String, Matrix{Float64}} with 6 entries: + "s" => [-0.802775; -1.38222; -2.09861;;] + "m" => [-8.91894; -7.51551; -7.46824;;] + "xs[1]" => [-5.41894; -5.26551; -5.63491;;] + "xs[2]" => [-2.91894; -3.51551; -4.13491;;] + "xs[3]" => [-1.41894; -2.26551; -2.96824;;] + "y" => [-0.918939; -1.51551; -2.13491;;] + +julia> pointwise_logdensities(model, chain, String) +OrderedDict{String, Matrix{Float64}} with 6 entries: + "s" => [-0.802775; -1.38222; -2.09861;;] + "m" => [-8.91894; -7.51551; -7.46824;;] + "xs[1]" => [-5.41894; -5.26551; -5.63491;;] + "xs[2]" => [-2.91894; -3.51551; -4.13491;;] + "xs[3]" => [-1.41894; -2.26551; -2.96824;;] + "y" => [-0.918939; -1.51551; -2.13491;;] + +julia> pointwise_logdensities(model, chain, VarName) +OrderedDict{VarName, Matrix{Float64}} with 6 entries: + s => [-0.802775; -1.38222; -2.09861;;] + m => [-8.91894; -7.51551; -7.46824;;] + xs[1] => [-5.41894; -5.26551; -5.63491;;] + xs[2] => [-2.91894; -3.51551; -4.13491;;] + xs[3] => [-1.41894; -2.26551; -2.96824;;] + y => [-0.918939; -1.51551; -2.13491;;] +``` + +## Broadcasting +Note that `x .~ Dist()` will treat `x` as a collection of +_independent_ observations rather than as a single observation. 
+ +```jldoctest; setup = :(using Distributions) +julia> @model function demo(x) + x .~ Normal() + end; + +julia> m = demo([1.0, ]); + +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first(ℓ[@varname(x[1])]) +-1.4189385332046727 + +julia> m = demo([1.0; 1.0]); + +julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) +(-1.4189385332046727, -1.4189385332046727) +``` + +""" +function pointwise_logdensities( + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() +) where {T} + # Get the data by executing the model once + vi = VarInfo(model) + point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) + + iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + for (sample_idx, chain_idx) in iters + # Update the values + setval!(vi, chain, sample_idx, chain_idx) + + # Execute model + model(vi, point_context) + end + + niters = size(chain, 1) + nchains = size(chain, 3) + logdensities = OrderedDict( + varname => reshape(logliks, niters, nchains) for + (varname, logliks) in point_context.logdensities + ) + return logdensities +end + +function pointwise_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() +) + point_context = PointwiseLogdensityContext( + OrderedDict{VarName,Vector{Float64}}(), context + ) + model(varinfo, point_context) + return point_context.logdensities +end + +""" + pointwise_loglikelihoods(model, chain[, keytype, context]) + +Compute the pointwise log-likelihoods of the model given the chain. +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the likelihood terms. +See also: [`pointwise_logdensities`](@ref). 
+""" +function pointwise_loglikelihoods( + model::Model, + chain, + keytype::Type{T}=String, + context::AbstractContext=LikelihoodContext(), +) where {T} + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, chain, T, context) +end + +function pointwise_loglikelihoods( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() +) + if !(leafcontext(context) isa LikelihoodContext) + throw(ArgumentError("Leaf context should be a LikelihoodContext")) + end + + return pointwise_logdensities(model, varinfo, context) +end + +""" + pointwise_prior_logdensities(model, chain[, keytype, context]) + +Compute the pointwise log-prior-densities of the model given the chain. +This is the same as `pointwise_logdensities(model, chain, context)`, but only +including the prior terms. +See also: [`pointwise_logdensities`](@ref). +""" +function pointwise_prior_logdensities( + model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() +) where {T} + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, chain, T, context) +end + +function pointwise_prior_logdensities( + model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() +) + if !(leafcontext(context) isa PriorContext) + throw(ArgumentError("Leaf context should be a PriorContext")) + end + + return pointwise_logdensities(model, varinfo, context) +end diff --git a/src/test_utils.jl b/src/test_utils.jl index 6f7481c40..8489f2684 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1042,4 +1042,44 @@ function test_context_interface(context) end end +""" +Context that multiplies each log-density term (assume and observe) by `mod`; +used to test whether `pointwise_logdensities` respects child-context. 
+""" +struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext + mod::T + context::Ctx +end +function TestLogModifyingChildContext( + mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext() +) + return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) +end + +DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context +function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) + return TestLogModifyingChildContext(context.mod, child) +end +function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) + value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) + return value, logp * context.mod, vi +end +function DynamicPPL.dot_tilde_assume( + context::TestLogModifyingChildContext, right, left, vn, vi +) + value, logp, vi = DynamicPPL.dot_tilde_assume(context.context, right, left, vn, vi) + return value, logp * context.mod, vi +end +function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) + logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) + return logp * context.mod, vi +end +function DynamicPPL.dot_tilde_observe( + context::TestLogModifyingChildContext, right, left, vi +) + logp, vi = DynamicPPL.dot_tilde_observe(context.context, right, left, vi) + return logp * context.mod, vi +end + end diff --git a/test/contexts.jl b/test/contexts.jl index 11e2c99b7..4ec9ff945 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,7 +8,7 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext, + PointwiseLogdensityContext, contextual_isassumption, ConditionContext, hasconditioned, @@ -67,7 +67,7 @@ end SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), - PointwiseLikelihoodContext(), + PointwiseLogdensityContext(), 
ConditionContext((x=1.0,)), ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl deleted file mode 100644 index 1075ce333..000000000 --- a/test/loglikelihoods.jl +++ /dev/null @@ -1,24 +0,0 @@ -@testset "loglikelihoods.jl" begin - @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.rand_prior_true(m) - - # Instantiate a `VarInfo` with the example values. - vi = VarInfo(m) - for vn in DynamicPPL.TestUtils.varnames(m) - vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) - end - - # Compute the pointwise loglikelihoods. - lls = pointwise_loglikelihoods(m, vi) - - if isempty(lls) - # One of the models with literal observations, so we just skip. - continue - end - - loglikelihood = sum(sum, values(lls)) - loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - - @test loglikelihood ≈ loglikelihood_true - end -end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl new file mode 100644 index 000000000..93b7c59be --- /dev/null +++ b/test/pointwise_logdensities.jl @@ -0,0 +1,101 @@ +@testset "logdensities_likelihoods.jl" begin + mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) + mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + example_values = DynamicPPL.TestUtils.rand_prior_true(model) + + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(model) + for vn in DynamicPPL.TestUtils.varnames(model) + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) + end + + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true( + model, example_values... + ) + logp_true = logprior(model, vi) + + # Compute the pointwise loglikelihoods. 
+ lls = pointwise_loglikelihoods(model, vi) + if isempty(lls) + # One of the models with literal observations, so we'll set this to 0 for subsequent comparisons. + loglikelihood_true = 0.0 + else + @test [:x] == unique(DynamicPPL.getsym.(keys(lls))) + loglikelihood_sum = sum(sum, values(lls)) + @test loglikelihood_sum ≈ loglikelihood_true + end + + # Compute the pointwise logdensities of the priors. + lps_prior = pointwise_prior_logdensities(model, vi) + @test :x ∉ DynamicPPL.getsym.(keys(lps_prior)) + logp = sum(sum, values(lps_prior)) + @test logp ≈ logp_true + + # Compute both likelihood and logdensity of prior + # using the default `DefaultContext`. + lps = pointwise_logdensities(model, vi) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) + + # Test that modifications made by the child contexts are picked up + lps = pointwise_logdensities(model, vi, mod_ctx2) + logp = sum(sum, values(lps)) + @test logp ≈ (logp_true + loglikelihood_true) * 1.2 * 1.4 + end +end + +@testset "pointwise_logdensities chain" begin + # We'll just test one, since `pointwise_logdensities(::Model, ::AbstractVarInfo)` is tested extensively, + # and this is what is used to implement `pointwise_logdensities(::Model, ::Chains)`. This test suite is just + # to ensure that we don't accidentally break the version on `Chains`. + model = DynamicPPL.TestUtils.demo_dot_assume_dot_observe() + # FIXME(torfjelde): Make use of `varname_and_value_leaves` once we've introduced + # an impl of this for containers. + # NOTE(torfjelde): This only returns the varnames of the _random_ variables, i.e. excl. observed. + vns = DynamicPPL.TestUtils.varnames(model) + # Get some random `NamedTuple` samples from the prior. + num_iters = 3 + vals = [DynamicPPL.TestUtils.rand_prior_true(model) for _ in 1:num_iters] + # Concatenate the vector representations and create a `Chains` from it. 
+ vals_arr = reduce(hcat, mapreduce(DynamicPPL.tovec, vcat, values(nt)) for nt in vals) + chain = Chains(permutedims(vals_arr), map(Symbol, vns)) + + # Compute the different pointwise logdensities. + logjoints_pointwise = pointwise_logdensities(model, chain) + logpriors_pointwise = pointwise_prior_logdensities(model, chain) + loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) + + # Check that they contain the correct variables. + @test all(string(vn) in keys(logjoints_pointwise) for vn in vns) + @test all(string(vn) in keys(logpriors_pointwise) for vn in vns) + @test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise)) + @test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns) + @test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise)) + + # Get the sum of the logjoints for each of the iterations. + logjoints = [ + sum(logjoints_pointwise[vn][idx] for vn in keys(logjoints_pointwise)) for + idx in 1:num_iters + ] + logpriors = [ + sum(logpriors_pointwise[vn][idx] for vn in keys(logpriors_pointwise)) for + idx in 1:num_iters + ] + loglikelihoods = [ + sum(loglikelihoods_pointwise[vn][idx] for vn in keys(loglikelihoods_pointwise)) for + idx in 1:num_iters + ] + + for (val, logjoint, logprior, loglikelihood) in + zip(vals, logjoints, logpriors, loglikelihoods) + # Compare true logjoint with the one obtained from `pointwise_logdensities`. + logjoint_true = DynamicPPL.TestUtils.logjoint_true(model, val...) + logprior_true = DynamicPPL.TestUtils.logprior_true(model, val...) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(model, val...) 
+ + @test logjoint ≈ logjoint_true + @test logprior ≈ logprior_true + @test loglikelihood ≈ loglikelihood_true + end +end diff --git a/test/runtests.jl b/test/runtests.jl index aa0883708..b9a1d92bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,8 @@ using Pkg using Random using Serialization using Test +using Distributions +using LinearAlgebra # Diagonal using DynamicPPL: getargs_dottilde, getargs_tilde, Selector @@ -53,7 +55,7 @@ include("test_util.jl") include("serialization.jl") - include("loglikelihoods.jl") + include("pointwise_logdensities.jl") include("lkj.jl") end @@ -93,6 +95,9 @@ include("test_util.jl") # Errors from macros sometimes result in `LoadError: LoadError:` # rather than `LoadError:`, depending on Julia version. r"ERROR: (LoadError:\s)+", + # Older versions do not have `;;]` but instead just `]` at end of the line + # => need to treat `;;]` and `]` as the same, i.e. ignore them if at the end of a line + r"(;;){0,1}\]$"m, ] doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) end