Skip to content

Commit

Permalink
avoid recording prior components on leaf-prior-context
Browse files Browse the repository at this point in the history
and avoid recording likelihoods when invoked with leaf-Likelihood context
  • Loading branch information
bgctw committed Sep 24, 2024
1 parent d9945d7 commit 656a757
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 27 deletions.
33 changes: 28 additions & 5 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,26 @@ function Base.push!(
return context.logdensities[vn] = logp
end


function _include_prior(context::PointwiseLogdensityContext)
return leafcontext(context) isa Union{PriorContext,DefaultContext}
end
function _include_likelihood(context::PointwiseLogdensityContext)
return leafcontext(context) isa Union{LikelihoodContext,DefaultContext}
end



function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi)
# Defer literal `observe` to child-context.
return tilde_observe!!(context.context, right, left, vi)
end
function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
# Completely defer to child context if we are not tracking likelihoods.
if !(_include_likelihood(context))
return tilde_observe!!(context.context, right, left, vn, vi)
end

# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
# we have to intercept the call to `tilde_observe!`.
logp, vi = tilde_observe(context.context, right, left, vi)
Expand All @@ -93,6 +108,11 @@ function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, v
return dot_tilde_observe!!(context.context, right, left, vi)
end
function dot_tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi)
# Completely defer to child context if we are not tracking likelihoods.
if !(_include_likelihood(context))
return dot_tilde_observe!!(context.context, right, left, vn, vi)
end

# Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e.
# we have to intercept the call to `dot_tilde_observe!`.

Expand Down Expand Up @@ -132,16 +152,19 @@ end
function tilde_assume(context::PointwiseLogdensityContext, right::Distribution, vn, vi)
#@info "PointwiseLogdensityContext tilde_assume called for $vn"
value, logp, vi = tilde_assume(context.context, right, vn, vi)
push!(context, vn, logp)
if _include_prior(context)
push!(context, vn, logp)
end
return value, logp, vi
end

function dot_tilde_assume(context::PointwiseLogdensityContext, right, left, vns, vi)
#@info "PointwiseLogdensityContext dot_tilde_assume called for $vns"
value, logp, vi_new = dot_tilde_assume(context.context, right, left, vns, vi)
# dispatch recording of log-densities based on type of right
logps = record_dot_tilde_assume(context, right, left, vns, vi, logp)
sum(logps) logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig")
if _include_prior(context)
logps = record_dot_tilde_assume(context, right, left, vns, vi, logp)
sum(logps) logp || error("Expected sum of individual logp equal origina, but differed sum($(join(logps, ","))) != $logp_orig")
end
return value, logp, vi
end

Expand Down Expand Up @@ -172,7 +195,7 @@ end
pointwise_logdensities(model::Model, chain::Chains, keytype = String)
Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
with keys corresponding to symbols of the observations, and values being matrices
with keys corresponding to symbols of the variables, and values being matrices
of shape `(num_chains, num_samples)`.
`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
Expand Down
6 changes: 2 additions & 4 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1116,10 +1116,8 @@ function TestLogModifyingChildContext(
mod, context
)
end
# Samplers call leafcontext(model.context) when evaluating log-densities
# Hence, in order to be used need to say that its a leaf-context
#DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsLeaf()

DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context
function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child)
return TestLogModifyingChildContext(context.mod, child)
Expand Down
5 changes: 2 additions & 3 deletions test/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@

# Compute the pointwise loglikelihoods.
lls = pointwise_loglikelihoods(m, vi)
loglikelihood = sum(sum, values(lls))

#if isempty(lls)
if loglikelihood 0.0 #isempty(lls)
if isempty(lls)
# One of the models with literal observations, so we just skip.
# TODO: Think of better way to detect this special case
continue
end

loglikelihood = sum(sum, values(lls))
loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...)

#priors =
Expand Down
28 changes: 13 additions & 15 deletions test/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
prior_context = PriorContext()
mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2)
mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx)
#m = DynamicPPL.TestUtils.DEMO_MODELS[12]
#m = DynamicPPL.TestUtils.DEMO_MODELS[1]
#m = model = DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2()
@testset "$(m.f)" for (i, m) in enumerate(DynamicPPL.TestUtils.DEMO_MODELS)
demo_models = (
DynamicPPL.TestUtils.DEMO_MODELS...,
DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix2())
@testset "$(m.f)" for (i, m) in enumerate(demo_models)
#@show i
example_values = DynamicPPL.TestUtils.rand_prior_true(m)

Expand All @@ -26,23 +29,19 @@
# Compute the pointwise loglikelihoods.
lls = pointwise_logdensities(m, vi, likelihood_context)
#lls2 = pointwise_loglikelihoods(m, vi)
loglikelihood_sum = sum(sum, values(lls))
if loglikelihood_sum 0.0 #isempty(lls)
if isempty(lls)
# One of the models with literal observations, so we just skip.
# TODO: Think of better way to detect this special case
loglikelihood_true = 0.0
else
loglikelihood_sum = sum(sum, values(lls))
@test loglikelihood_sum loglikelihood_true
end
@test loglikelihood_sum loglikelihood_true

# Compute the pointwise logdensities of the priors.
lps_prior = pointwise_logdensities(m, vi, prior_context)
logp = sum(sum, values(lps_prior))
if false # isempty(lps_prior)
# One of the models with only observations so we just skip.
else
logp1 = getlogp(vi)
@test !isfinite(logp_true) || logp logp_true
end
logp1 = getlogp(vi)
@test !isfinite(logp_true) || logp logp_true

# Compute both likelihood and logdensity of prior
# using the default DefaultContex
Expand All @@ -57,7 +56,6 @@
end
end


@testset "pointwise_logdensities chain" begin
@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
s ~ InverseGamma(2, 3)
Expand All @@ -73,9 +71,9 @@ end
# generate the sample used below
chain = sample(model, MH(), MCMCThreads(), 10, 2)
arr0 = stack(Array(chain, append_chains=false))
@show(arr0);
@show(arr0[1:2,:,:]);
end
arr0 = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 0.9199555480151707 -0.1304320097505629 1.0669120062696917 -0.05253734412139093; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183; 1.0982766276744311 0.9593277181079177 0.005558174156359029 -0.45842032209694183;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 2.5409470583244933 1.7838744695696407 0.7013562890105632 -3.0843947804314658; 0.8296370582311665 1.5360702767879642 -1.5964695255693102 0.16928084806166913; 2.6246697053824954 0.8096845024785173 -1.2621822861663752 1.1414885535466166; 1.1304261861894538 0.7325784741344005 -1.1866016911837542 -0.1639319562090826; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 2.5669872989791473 -0.43642462460747317 0.07057300935786101 0.5168578624259272; 0.9838526141898173 -0.20198797220982412 2.0569535882007006 -1.1560724118010939]
arr0[1:2, :, :] = [5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317; 5.590726417006858 -3.3407908212996493 -3.5126580698975687 -0.02830755634462317;;; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497; 3.5612802961176797 -5.167692608117693 1.3768066487740864 -0.9154694769223497]
chain = Chains(arr0, [:s, Symbol("m[1]"), Symbol("m[2]"), Symbol("m[3]")]);
tmp1 = pointwise_logdensities(model, chain)
vi = VarInfo(model)
Expand Down

0 comments on commit 656a757

Please sign in to comment.