diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl
index fecb9adec..0ee75efb6 100644
--- a/test/mcmc/gibbs.jl
+++ b/test/mcmc/gibbs.jl
@@ -57,7 +57,8 @@ has_dot_assume(::DynamicPPL.Model) = true
             end
             s = sum(y) - sum(z)
             obs1 ~ Normal(s, 1)
-            return obs2 ~ Poisson(y[3])
+            obs2 ~ Poisson(y[3])
+            return obs1, obs2, variance, z, y, s
         end
 
         model = test_model(1.2, 2, 10, 2.5)
@@ -68,53 +69,45 @@ has_dot_assume(::DynamicPPL.Model) = true
                 n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames)
             ),
         )
-        for typed in (true, false)
-            for target_vns in target_vn_combinations
-                global_varinfo =
-                    typed ? DynamicPPL.VarInfo(model) : DynamicPPL.untyped_varinfo(model)
-                target_vns = collect(target_vns)
-                local_varinfo = DynamicPPL.subset(global_varinfo, target_vns)
-                ctx = Turing.Inference.GibbsContext(
-                    target_vns, Ref(global_varinfo), Turing.DefaultContext()
-                )
-
-                # Check that the correct varnames are conditioned, and that getting their
-                # values is type stable when the varinfo is.
-                for k in keys(global_varinfo)
-                    is_target = any(
-                        Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)
-                    )
-                    @test Turing.Inference.is_target_varname(ctx, k) == is_target
-                    if !is_target && typed
-                        @inferred Turing.Inference.get_conditioned_gibbs(ctx, k)
-                    end
-                end
+        for target_vns in target_vn_combinations
+            global_varinfo = DynamicPPL.VarInfo(model)
+            target_vns = collect(target_vns)
+            local_varinfo = DynamicPPL.subset(global_varinfo, target_vns)
+            ctx = Turing.Inference.GibbsContext(
+                target_vns, Ref(global_varinfo), Turing.DefaultContext()
+            )
 
-                # Check the type stability also in the dot_tilde pipeline.
-                for k in all_varnames
-                    # The map(identity, ...) part is there to concretise the eltype.
-                    subkeys = map(
-                        identity,
-                        filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)),
-                    )
-                    is_target = (k in target_vns)
-                    @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target
-                    if !is_target && typed
-                        @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys)
-                    end
+            # Check that the correct varnames are conditioned, and that getting their
+            # values is type stable when the varinfo is.
+            for k in keys(global_varinfo)
+                is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns))
+                @test Turing.Inference.is_target_varname(ctx, k) == is_target
+                if !is_target
+                    @inferred Turing.Inference.get_conditioned_gibbs(ctx, k)
                 end
+            end
 
-                # Check that evaluate!! and the result it returns are type stable.
-                conditioned_model = DynamicPPL.contextualize(model, ctx)
-                _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!(
-                    conditioned_model, local_varinfo
+            # Check the type stability also in the dot_tilde pipeline.
+            for k in all_varnames
+                # The map(identity, ...) part is there to concretise the eltype.
+                subkeys = map(
+                    identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo))
                 )
-                if typed
-                    for k in keys(post_eval_varinfo)
-                        @inferred post_eval_varinfo[k]
-                    end
+                is_target = (k in target_vns)
+                @test Turing.Inference.is_target_varname(ctx, subkeys) == is_target
+                if !is_target
+                    @inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys)
                 end
             end
+
+            # Check that evaluate!! and the result it returns are type stable.
+            conditioned_model = DynamicPPL.contextualize(model, ctx)
+            _, post_eval_varinfo = @inferred DynamicPPL.evaluate!!(
+                conditioned_model, local_varinfo
+            )
+            for k in keys(post_eval_varinfo)
+                @inferred post_eval_varinfo[k]
+            end
         end
     end
 end