Skip to content

Commit

Permalink
undeprecate pointwise_loglikelihoods and implement pointwise_prior_lo…
Browse files Browse the repository at this point in the history
…gdensities

mostly taken from TuringLang#669
  • Loading branch information
bgctw committed Sep 24, 2024
1 parent 656a757 commit 7aa9ebe
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 43 deletions.
4 changes: 2 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ export AbstractVarInfo,
# Convenience functions
logprior,
logjoint,
pointwise_loglikelihoods,
pointwise_prior_logdensities,
pointwise_logdensities,
pointwise_loglikelihoods,
condition,
decondition,
fix,
Expand Down Expand Up @@ -190,7 +191,6 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
include("deprecated.jl")

include("debug_utils.jl")
using .DebugUtils
Expand Down
9 changes: 0 additions & 9 deletions src/deprecated.jl

This file was deleted.

59 changes: 59 additions & 0 deletions src/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,62 @@ function pointwise_logdensities(model::Model,
end




"""
pointwise_loglikelihoods(model, chain[, keytype, context])
Compute the pointwise log-likelihoods of the model given the chain.
This is the same as `pointwise_logdensities(model, chain, context)`, but only
including the likelihood terms.
See also: [`pointwise_logdensities`](@ref).
"""
function pointwise_loglikelihoods(
model::Model,
chain,
keytype::Type{T}=String,
context::AbstractContext=LikelihoodContext(),
) where {T}
if !(leafcontext(context) isa LikelihoodContext)
throw(ArgumentError("Leaf context should be a LikelihoodContext"))
end

return pointwise_logdensities(model, chain, T, context)
end

function pointwise_loglikelihoods(
model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext()
)
if !(leafcontext(context) isa LikelihoodContext)
throw(ArgumentError("Leaf context should be a LikelihoodContext"))
end

return pointwise_logdensities(model, varinfo, context)
end

"""
pointwise_prior_logdensities(model, chain[, keytype, context])
Compute the pointwise log-prior-densities of the model given the chain.
This is the same as `pointwise_logdensities(model, chain, context)`, but only
including the prior terms.
See also: [`pointwise_logdensities`](@ref).
"""
function pointwise_prior_logdensities(
model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext()
) where {T}
if !(leafcontext(context) isa PriorContext)
throw(ArgumentError("Leaf context should be a PriorContext"))
end

return pointwise_logdensities(model, chain, T, context)
end

function pointwise_prior_logdensities(
model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext()
)
if !(leafcontext(context) isa PriorContext)
throw(ArgumentError("Leaf context should be a PriorContext"))
end

return pointwise_logdensities(model, varinfo, context)
end

28 changes: 0 additions & 28 deletions test/deprecated.jl

This file was deleted.

5 changes: 3 additions & 2 deletions test/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
logp_true = logprior(m, vi)

# Compute the pointwise loglikelihoods.
lls = pointwise_logdensities(m, vi, likelihood_context)
lls = pointwise_loglikelihoods(m, vi)
#lls2 = pointwise_loglikelihoods(m, vi)
if isempty(lls)
# One of the models with literal observations, so we just skip.
Expand All @@ -38,7 +38,7 @@
end

# Compute the pointwise logdensities of the priors.
lps_prior = pointwise_logdensities(m, vi, prior_context)
lps_prior = pointwise_prior_logdensities(m, vi)
logp = sum(sum, values(lps_prior))
logp1 = getlogp(vi)
@test !isfinite(logp_true) || logp logp_true
Expand All @@ -56,6 +56,7 @@
end
end


@testset "pointwise_logdensities chain" begin
@model function demo(x, ::Type{TV}=Vector{Float64}) where {TV}
s ~ InverseGamma(2, 3)
Expand Down
2 changes: 0 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ include("test_util.jl")
include("pointwise_logdensities.jl")

include("lkj.jl")

include("deprecated.jl")
end

@testset "compat" begin
Expand Down

0 comments on commit 7aa9ebe

Please sign in to comment.