From 29a6c7ec0cd62d3a4d1dc18a304d5e4d1e024cfb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 27 Jan 2025 15:23:49 +0000 Subject: [PATCH] Handle nested PrefixContext (#787) * Prefix varnames appropriately inside check_model_and_trace * Fix values_as_in_model as well * Add test for check_model with manual prefix * Add values_as_in_model tests * Add tests for prefix nesting * Bump Project.toml --- Project.toml | 2 +- src/contexts.jl | 15 +++++++++------ src/debug_utils.jl | 23 ++++++++++++----------- src/values_as_in_model.jl | 2 +- test/contexts.jl | 20 ++++++++++++++++++++ test/debug_utils.jl | 9 +++++++++ test/model.jl | 21 +++++++++++++++++++++ 7 files changed, 73 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index bd553c0cc..3df611824 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.34.1" +version = "0.34.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/contexts.jl b/src/contexts.jl index a9470fbb6..99b2136f3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -261,6 +261,7 @@ end const PREFIX_SEPARATOR = Symbol(".") +# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here function PrefixContext{PrefixInner}( context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} @@ -273,13 +274,15 @@ function PrefixContext{PrefixInner}( end end -function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getoptic(vn))) - else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn)) - end +# TODO(penelopeysm): Prefixing arguably occurs the wrong way round here +function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} + return prefix( + childcontext(ctx), VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getoptic(vn)) + ) end +prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn) +prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn +prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn) """ prefix(model::Model, x) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index f486482a9..43b5054d5 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -239,42 +239,43 @@ function DynamicPPL.setchildcontext(context::DebugContext, child) end function record_varname!(context::DebugContext, varname::VarName, dist) - if haskey(context.varnames_seen, varname) + prefixed_varname = prefix(context, varname) + if haskey(context.varnames_seen, prefixed_varname) if context.error_on_failure - error("varname $varname used multiple times in model") + error("varname $prefixed_varname used multiple times in model") else - @warn "varname $varname used multiple times in model" + @warn "varname $prefixed_varname used multiple times in model" end - context.varnames_seen[varname] += 1 + context.varnames_seen[prefixed_varname] += 1 else # We need to check: # 1. Does this `varname` subsume any of the other keys. # 2. Does any of the other keys subsume `varname`. vns = collect(keys(context.varnames_seen)) # Is `varname` subsumed by any of the other keys? - idx_parent = findfirst(Base.Fix2(subsumes, varname), vns) + idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns) if idx_parent !== nothing varname_parent = vns[idx_parent] if context.error_on_failure error( - "varname $(varname_parent) used multiple times in model (subsumes $varname)", + "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)", ) else - @warn "varname $(varname_parent) used multiple times in model (subsumes $varname)" + @warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)" end # Update count of parent. context.varnames_seen[varname_parent] += 1 else # Does `varname` subsume any of the other keys? - idx_child = findfirst(Base.Fix1(subsumes, varname), vns) + idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns) if idx_child !== nothing varname_child = vns[idx_child] if context.error_on_failure error( - "varname $(varname_child) used multiple times in model (subsumed by $varname)", + "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)", ) else - @warn "varname $(varname_child) used multiple times in model (subsumed by $varname)" + @warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)" end # Update count of child. @@ -282,7 +283,7 @@ function record_varname!(context::DebugContext, varname::VarName, dist) end end - context.varnames_seen[varname] = 1 + context.varnames_seen[prefixed_varname] = 1 end end diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index ca8cc1cb3..4cef5fa4e 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -45,7 +45,7 @@ is_extracting_values(::IsParent, ::AbstractContext) = false is_extracting_values(::IsLeaf, ::AbstractContext) = false function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) - return setindex!(context.values, copy(value), vn) + return setindex!(context.values, copy(value), prefix(context, vn)) end function broadcast_push!(context::ValuesAsInModelContext, vns, values) diff --git a/test/contexts.jl b/test/contexts.jl index dd3b4c90c..ef55335d0 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -162,6 +162,26 @@ end @test getoptic(vn_prefixed) === getoptic(vn) end + @testset "nested within arbitrary context stacks" begin + vn = @varname(x[1]) + ctx1 = PrefixContext{:a}(DefaultContext()) + ctx2 = SamplingContext(ctx1) + ctx3 = PrefixContext{:b}(ctx2) + ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) + vn_prefixed1 = prefix(ctx1, vn) + vn_prefixed2 = prefix(ctx2, vn) + vn_prefixed3 = prefix(ctx3, vn) + vn_prefixed4 = prefix(ctx4, vn) + @test DynamicPPL.getsym(vn_prefixed1) == Symbol("a.x") + @test DynamicPPL.getsym(vn_prefixed2) == Symbol("a.x") + @test DynamicPPL.getsym(vn_prefixed3) == Symbol("a.b.x") + @test DynamicPPL.getsym(vn_prefixed4) == Symbol("a.b.x") + @test DynamicPPL.getoptic(vn_prefixed1) === DynamicPPL.getoptic(vn) + @test DynamicPPL.getoptic(vn_prefixed2) === DynamicPPL.getoptic(vn) + @test DynamicPPL.getoptic(vn_prefixed3) === DynamicPPL.getoptic(vn) + @test DynamicPPL.getoptic(vn_prefixed4) === DynamicPPL.getoptic(vn) + end + context = DynamicPPL.PrefixContext{:prefix}(SamplingContext()) @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS # Sample with the context. diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 294364758..d4f6601f5 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -60,6 +60,15 @@ end model = ModelOuterWorking() @test check_model(model; error_on_failure=true) + + # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 + @model function ModelOuterWorking2() + x1 ~ to_submodel(prefix(ModelInner(), :a), false) + x2 ~ to_submodel(prefix(ModelInner(), :b), false) + return (x1, x2) + end + model = ModelOuterWorking2() + @test check_model(model; error_on_failure=true) end @testset "subsumes (x then x[1])" begin diff --git a/test/model.jl b/test/model.jl index 45c770cc4..118f60a40 100644 --- a/test/model.jl +++ b/test/model.jl @@ -429,6 +429,27 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end end + + @testset "Prefixing" begin + @model inner() = x ~ Normal() + + @model function outer_auto_prefix() + a ~ to_submodel(inner(), true) + b ~ to_submodel(inner(), true) + return nothing + end + @model function outer_manual_prefix() + a ~ to_submodel(prefix(inner(), :a), false) + b ~ to_submodel(prefix(inner(), :b), false) + return nothing + end + + for model in (outer_auto_prefix(), outer_manual_prefix()) + vi = VarInfo(model) + vns = Set(keys(values_as_in_model(model, false, vi))) + @test vns == Set([@varname(var"a.x"), @varname(var"b.x")]) + end + end end @testset "Erroneous model call" begin