diff --git a/Project.toml b/Project.toml index 60dbcdc81..2bf60214f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.33.0" +version = "0.33.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index d5c6bd690..093cb06a6 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -65,7 +65,7 @@ DynamicPPL.LogDensityFunction A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref). ```@docs -|(::Model, ::Any) +|(::Model, ::Union{Tuple,NamedTuple,AbstractDict{<:VarName}}) condition DynamicPPL.conditioned ``` @@ -403,6 +403,7 @@ LikelihoodContext PriorContext MiniBatchContext PrefixContext +ConditionContext ``` ### Samplers diff --git a/src/contexts.jl b/src/contexts.jl index b337e4750..a9470fbb6 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -309,7 +309,20 @@ function prefix(model::Model, ::Val{x}) where {x} return contextualize(model, PrefixContext{Symbol(x)}(model.context)) end -struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext +""" + + ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} + +Model context that contains values that are to be conditioned on. The values +can either be a NamedTuple mapping symbols to values, such as `(a=1, b=2)`, or +an AbstractDict mapping varnames to values (e.g. `Dict(@varname(a) => 1, +@varname(b) => 2)`). The former is more performant, but the latter must be used +when there are varnames that cannot be represented as symbols, e.g. +`@varname(x[1])`. +""" +struct ConditionContext{ + Values<:Union{NamedTuple,AbstractDict{<:VarName}},Ctx<:AbstractContext +} <: AbstractContext values::Values context::Ctx end @@ -317,12 +330,19 @@ end const NamedConditionContext{Names} = ConditionContext{<:NamedTuple{Names}} const DictConditionContext = ConditionContext{<:AbstractDict} -ConditionContext(values) = ConditionContext(values, DefaultContext()) - -# Try to avoid nested `ConditionContext`. +# Use DefaultContext as the default base context +function ConditionContext(values::Union{NamedTuple,AbstractDict}) + return ConditionContext(values, DefaultContext()) +end +# Optimisation when there are no values to condition on +ConditionContext(::NamedTuple{()}, context::AbstractContext) = context +# Collapse consecutive levels of `ConditionContext`. Note that this overrides +# values inside the child context, thus giving precedence to the outermost +# `ConditionContext`. function ConditionContext(values::NamedTuple, context::NamedConditionContext) - # Note that this potentially overrides values from `context`, thus giving - # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), childcontext(context)) +end +function ConditionContext(values::AbstractDict{<:VarName}, context::DictConditionContext) return ConditionContext(merge(context.values, values), childcontext(context)) end @@ -399,43 +419,6 @@ function getconditioned_nested(::IsParent, context, vn) end end -""" - condition([context::AbstractContext,] values::NamedTuple) - condition([context::AbstractContext]; values...) - -Return `ConditionContext` with `values` and `context` if `values` is non-empty, -otherwise return `context` which is [`DefaultContext`](@ref) by default. - -See also: [`decondition`](@ref) -""" -AbstractPPL.condition(; values...) = condition(NamedTuple(values)) -AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values) -function AbstractPPL.condition(value::Pair{<:VarName}, values::Pair{<:VarName}...) - return condition((value, values...)) -end -function AbstractPPL.condition(values::NTuple{<:Any,<:Pair{<:VarName}}) - return condition(DefaultContext(), values) -end -AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context -function AbstractPPL.condition( - context::AbstractContext, values::Union{AbstractDict,NamedTuple} -) - return ConditionContext(values, context) -end -function AbstractPPL.condition(context::AbstractContext; values...) - return condition(context, NamedTuple(values)) -end -function AbstractPPL.condition( - context::AbstractContext, value::Pair{<:VarName}, values::Pair{<:VarName}... -) - return condition(context, (value, values...)) -end -function AbstractPPL.condition( - context::AbstractContext, values::NTuple{<:Any,Pair{<:VarName}} -) - return condition(context, Dict(values)) -end - """ decondition(context::AbstractContext, syms...) @@ -445,41 +428,34 @@ Note that this recursively traverses contexts, deconditioning all along the way. See also: [`condition`](@ref) """ -AbstractPPL.decondition(::IsLeaf, context, args...) = context -function AbstractPPL.decondition(::IsParent, context, args...) - return setchildcontext(context, decondition(childcontext(context), args...)) +decondition_context(::IsLeaf, context, args...) = context +function decondition_context(::IsParent, context, args...) + return setchildcontext(context, decondition_context(childcontext(context), args...)) end -function AbstractPPL.decondition(context, args...) - return decondition(NodeTrait(context), context, args...) +function decondition_context(context, args...) + return decondition_context(NodeTrait(context), context, args...) end -function AbstractPPL.decondition(context::ConditionContext) - return decondition(childcontext(context)) -end -function AbstractPPL.decondition(context::ConditionContext, sym) - return condition( - decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym) - ) +function decondition_context(context::ConditionContext) + return decondition_context(childcontext(context)) end -function AbstractPPL.decondition(context::ConditionContext, sym, syms...) - return decondition( - condition( - decondition(childcontext(context), syms...), - BangBang.delete!!(context.values, sym), - ), - syms..., - ) -end - -function AbstractPPL.decondition( - context::NamedConditionContext, vn::VarName{sym} -) where {sym} - return condition( - decondition(childcontext(context), vn), BangBang.delete!!(context.values, sym) - ) +function decondition_context(context::ConditionContext, sym, syms...) + new_values = deepcopy(context.values) + for s in (sym, syms...) + new_values = BangBang.delete!!(new_values, s) + end + return if length(new_values) == 0 + # No more values left, can unwrap + decondition_context(childcontext(context), syms...) + else + ConditionContext( + new_values, decondition_context(childcontext(context), sym, syms...) + ) + end end -function AbstractPPL.decondition(context::ConditionContext, vn::VarName) - return condition( - decondition(childcontext(context), vn), BangBang.delete!!(context.values, vn) +function decondition_context(context::NamedConditionContext, vn::VarName{sym}) where {sym} + return ConditionContext( + BangBang.delete!!(context.values, sym), + decondition_context(childcontext(context), vn), ) end diff --git a/src/model.jl b/src/model.jl index 2bad6f1fe..6fb0b40b0 100644 --- a/src/model.jl +++ b/src/model.jl @@ -96,7 +96,8 @@ Return a `Model` which now treats variables on the right-hand side as observatio See [`condition`](@ref) for more information and examples. """ -Base.:|(model::Model, values) = condition(model, values) +Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) = + condition(model, values) """ condition(model::Model; values...) @@ -264,11 +265,32 @@ julia> conditioned_model_dict() 1.0 ``` """ -AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values)) -function AbstractPPL.condition(model::Model, value, values...) - return contextualize(model, condition(model.context, value, values...)) +function AbstractPPL.condition(model::Model, values...) + # Positional arguments - need to handle cases carefully + return contextualize( + model, ConditionContext(_make_conditioning_values(values...), model.context) + ) +end +function AbstractPPL.condition(model::Model; values...) + # Keyword arguments -- just convert to a NamedTuple + return contextualize(model, ConditionContext(NamedTuple(values), model.context)) end +""" + _make_conditioning_values(vals...) + +Convert different types of input to either a `NamedTuple` or `AbstractDict` of +conditioning values, suitable for storage in a `ConditionContext`. + +This handles all the cases where `vals` is either already a NamedTuple or +AbstractDict (e.g. `model | (x=1, y=2)`), as well as if they are splatted (e.g. +`condition(model, x=1, y=2)`). +""" +_make_conditioning_values(values::Union{NamedTuple,AbstractDict}) = values +_make_conditioning_values(values::Tuple{Pair{<:VarName}}) = Dict(values) +_make_conditioning_values(v::Pair{<:Symbol}, vs::Pair{<:Symbol}...) = NamedTuple(v, vs...) +_make_conditioning_values(v::Pair{<:VarName}, vs::Pair{<:VarName}...) = Dict(v, vs...) + """ decondition(model::Model) decondition(model::Model, variables...) @@ -379,7 +401,7 @@ true ``` """ function AbstractPPL.decondition(model::Model, syms...) - return contextualize(model, decondition(model.context, syms...)) + return contextualize(model, decondition_context(model.context, syms...)) end """ @@ -413,7 +435,7 @@ julia> # Returns all the variables we have conditioned on + their values. (x = 100.0, m = 1.0) julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). - cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0); + cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0); julia> conditioned(cm) (x = 100.0, m = 1.0) @@ -425,7 +447,7 @@ julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed a.m julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. - cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0); + cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0); julia> conditioned(cm).x 100.0 @@ -433,7 +455,7 @@ julia> conditioned(cm).x julia> conditioned(cm).var"a.m" 1.0 -julia> keys(VarInfo(cm)) # <= no variables are sampled +julia> keys(VarInfo(cm)) # No variables are sampled VarName[] ``` """ diff --git a/test/contexts.jl b/test/contexts.jl index 7a7826466..dd3b4c90c 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -11,6 +11,7 @@ using DynamicPPL: PointwiseLogdensityContext, contextual_isassumption, ConditionContext, + decondition_context, hasconditioned, getconditioned, hasconditioned_nested, @@ -196,6 +197,88 @@ end @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) end + @testset "ConditionContext" begin + @testset "Nesting" begin + @testset "NamedTuple" begin + n1 = (x=1, y=2) + n2 = (x=3,) + # Values from outer context should override inner one + ctx1 = ConditionContext(n1, ConditionContext(n2)) + @test ctx1.values == (x=1, y=2) + # Check that the two ConditionContexts are collapsed + @test childcontext(ctx1) isa DefaultContext + # Then test the nesting the other way round + ctx2 = ConditionContext(n2, ConditionContext(n1)) + @test ctx2.values == (x=3, y=2) + @test childcontext(ctx2) isa DefaultContext + end + + @testset "Dict" begin + # Same tests as NamedTuple above + d1 = Dict(@varname(x) => 1, @varname(y) => 2) + d2 = Dict(@varname(x) => 3) + ctx1 = ConditionContext(d1, ConditionContext(d2)) + @test ctx1.values == Dict(@varname(x) => 1, @varname(y) => 2) + @test childcontext(ctx1) isa DefaultContext + ctx2 = ConditionContext(d2, ConditionContext(d1)) + @test ctx2.values == Dict(@varname(x) => 3, @varname(y) => 2) + @test childcontext(ctx2) isa DefaultContext + end + end + + @testset "decondition_context" begin + @testset "NamedTuple" begin + ctx = ConditionContext((x=1, y=2, z=3)) + # Decondition all variables + @test decondition_context(ctx) isa DefaultContext + # Decondition only some variables + dctx = decondition_context(ctx, :x) + @test dctx isa ConditionContext + @test dctx.values == (y=2, z=3) + dctx = decondition_context(ctx, :y, :z) + @test dctx isa ConditionContext + @test dctx.values == (x=1,) + # Decondition all variables manually + @test decondition_context(ctx, :x, :y, :z) isa DefaultContext + end + + @testset "Dict" begin + ctx = ConditionContext( + Dict(@varname(x) => 1, @varname(y) => 2, @varname(z) => 3) + ) + # Decondition all variables + @test decondition_context(ctx) isa DefaultContext + # Decondition only some variables + dctx = decondition_context(ctx, @varname(x)) + @test dctx isa ConditionContext + @test dctx.values == Dict(@varname(y) => 2, @varname(z) => 3) + dctx = decondition_context(ctx, @varname(y), @varname(z)) + @test dctx isa ConditionContext + @test dctx.values == Dict(@varname(x) => 1) + # Decondition all variables manually + @test decondition_context(ctx, @varname(x), @varname(y), @varname(z)) isa + DefaultContext + end + + @testset "Nesting" begin + ctx = ConditionContext( + (x=1, y=2), ConditionContext(Dict(@varname(a) => 3, @varname(b) => 4)) + ) + # Decondition an outer variable + dctx = decondition_context(ctx, :x) + @test dctx.values == (y=2,) + @test childcontext(dctx).values == Dict(@varname(a) => 3, @varname(b) => 4) + # Decondition an inner variable + dctx = decondition_context(ctx, @varname(a)) + @test dctx.values == (x=1, y=2) + @test childcontext(dctx).values == Dict(@varname(b) => 4) + # Try deconditioning everything + dctx = decondition_context(ctx) + @test dctx isa DefaultContext + end + end + end + @testset "FixedContext" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS retval = model() diff --git a/test/model.jl b/test/model.jl index 96c0f1560..eb8d6a932 100644 --- a/test/model.jl +++ b/test/model.jl @@ -100,6 +100,39 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end + @testset "model de/conditioning" begin + @model function demo_condition() + x ~ Normal() + return y ~ Normal(x) + end + model = demo_condition() + + # Test that different syntaxes work and give the same underlying ConditionContext + @testset "conditioning NamedTuple" begin + expected_values = (y=2,) + @test condition(model, (y=2,)).context.values == expected_values + @test condition(model; y=2).context.values == expected_values + @test condition(model; y=2).context.values == expected_values + @test (model | (y=2,)).context.values == expected_values + conditioned_model = condition(model, (y=2,)) + @test keys(VarInfo(conditioned_model)) == [@varname(x)] + end + @testset "conditioning AbstractDict" begin + expected_values = Dict(@varname(y) => 2) + @test condition(model, Dict(@varname(y) => 2)).context.values == expected_values + @test condition(model, @varname(y) => 2).context.values == expected_values + @test (model | (@varname(y) => 2,)).context.values == expected_values + conditioned_model = condition(model, Dict(@varname(y) => 2)) + @test keys(VarInfo(conditioned_model)) == [@varname(x)] + end + + @testset "deconditioning" begin + conditioned_model = condition(model, (y=2,)) + deconditioned_model = decondition(conditioned_model) + @test keys(VarInfo(deconditioned_model)) == [@varname(x), @varname(y)] + end + end + @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin @model function multiple_types(x) ns ~ filldist(Normal(0, 2.0), 3)