diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d7c24fafb..50919e77e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -77,44 +77,6 @@ function tilde_assume( return tilde_assume(rng, childcontext(context), args...) end -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(PriorContext(), right, vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi -) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) -end - -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(LikelihoodContext(), right, vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - vn, - vi, -) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) -end function tilde_assume(::LikelihoodContext, right, vn, vi) return assume(NoDist(right), vn, vi) end @@ -328,37 +290,6 @@ function dot_tilde_assume( end # `LikelihoodContext` -function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) - else - dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - left, - vn, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) - else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) - end -end - function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) return dot_assume(nodist(right), left, vn, vi) end @@ -368,38 +299,6 @@ function dot_tilde_assume( return dot_assume(rng, sampler, nodist(right), vn, left, vi) end -# `PriorContext` -function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) - else - dot_tilde_assume(PriorContext(), right, left, vn, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - left, - vn, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) - else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) - end -end - # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) diff --git a/src/contexts.jl b/src/contexts.jl index 53b454df6..5da4208b5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -53,7 +53,7 @@ DefaultContext() julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior julia> DynamicPPL.childcontext(ctx_prior) -PriorContext{Nothing}(nothing) +PriorContext() ``` """ setchildcontext @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext{Nothing}(nothing) +PriorContext() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end NodeTrait(context::DefaultContext) = IsLeaf() """ - struct PriorContext{Tvars} <: AbstractContext - vars::Tvars - end + PriorContext <: AbstractContext -The `PriorContext` enables the computation of the log prior of the parameters `vars` when -running the model. +A leaf context resulting in the exclusion of likelihood terms when running the model. """ -struct PriorContext{Tvars} <: AbstractContext - vars::Tvars -end -PriorContext() = PriorContext(nothing) +struct PriorContext <: AbstractContext end NodeTrait(context::PriorContext) = IsLeaf() """ - struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars - end + LikelihoodContext <: AbstractContext -The `LikelihoodContext` enables the computation of the log likelihood of the parameters when -running the model. `vars` can be used to evaluate the log likelihood for specific values -of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. +A leaf context resulting in the exclusion of prior terms when running the model. """ -struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars -end -LikelihoodContext() = LikelihoodContext(nothing) +struct LikelihoodContext <: AbstractContext end NodeTrait(context::LikelihoodContext) = IsLeaf() """