Skip to content

Commit

Permalink
Improve GibbsContext type stability test
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Nov 12, 2024
1 parent 14decaf commit a7317e8
Showing 1 changed file with 35 additions and 42 deletions.
77 changes: 35 additions & 42 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ has_dot_assume(::DynamicPPL.Model) = true
end
s = sum(y) - sum(z)
obs1 ~ Normal(s, 1)
return obs2 ~ Poisson(y[3])
obs2 ~ Poisson(y[3])
return obs1, obs2, variance, z, y, s
end

model = test_model(1.2, 2, 10, 2.5)
Expand All @@ -68,53 +69,45 @@ has_dot_assume(::DynamicPPL.Model) = true
n -> Combinatorics.combinations(all_varnames, n), 1:length(all_varnames)
),
)
for typed in (true, false)
for target_vns in target_vn_combinations
global_varinfo =
typed ? DynamicPPL.VarInfo(model) : DynamicPPL.untyped_varinfo(model)
target_vns = collect(target_vns)
local_varinfo = DynamicPPL.subset(global_varinfo, target_vns)
ctx = Turing.Inference.GibbsContext(
target_vns, Ref(global_varinfo), Turing.DefaultContext()
)

# Check that the correct varnames are conditioned, and that getting their
# values is type stable when the varinfo is.
for k in keys(global_varinfo)
is_target = any(
Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns)
)
@test Turing.Inference.is_target_varname(ctx, k) == is_target
if !is_target && typed
@inferred Turing.Inference.get_conditioned_gibbs(ctx, k)
end
end
for target_vns in target_vn_combinations
global_varinfo = DynamicPPL.VarInfo(model)
target_vns = collect(target_vns)
local_varinfo = DynamicPPL.subset(global_varinfo, target_vns)
ctx = Turing.Inference.GibbsContext(
target_vns, Ref(global_varinfo), Turing.DefaultContext()
)

# Check the type stability also in the dot_tilde pipeline.
for k in all_varnames
# The map(identity, ...) part is there to concretise the eltype.
subkeys = map(
identity,
filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo)),
)
is_target = (k in target_vns)
@test Turing.Inference.is_target_varname(ctx, subkeys) == is_target
if !is_target && typed
@inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys)
end
# Check that the correct varnames are conditioned, and that getting their
# values is type stable when the varinfo is.
for k in keys(global_varinfo)
is_target = any(Iterators.map(vn -> DynamicPPL.subsumes(vn, k), target_vns))
@test Turing.Inference.is_target_varname(ctx, k) == is_target
if !is_target
@inferred Turing.Inference.get_conditioned_gibbs(ctx, k)
end
end

# Check that evaluate!! and the result it returns are type stable.
conditioned_model = DynamicPPL.contextualize(model, ctx)
_, post_eval_varinfo = @inferred DynamicPPL.evaluate!!(
conditioned_model, local_varinfo
# Check the type stability also in the dot_tilde pipeline.
for k in all_varnames
# The map(identity, ...) part is there to concretise the eltype.
subkeys = map(
identity, filter(vn -> DynamicPPL.subsumes(k, vn), keys(global_varinfo))
)
if typed
for k in keys(post_eval_varinfo)
@inferred post_eval_varinfo[k]
end
is_target = (k in target_vns)
@test Turing.Inference.is_target_varname(ctx, subkeys) == is_target
if !is_target
@inferred Turing.Inference.get_conditioned_gibbs(ctx, subkeys)
end
end

# Check that evaluate!! and the result it returns are type stable.
conditioned_model = DynamicPPL.contextualize(model, ctx)
_, post_eval_varinfo = @inferred DynamicPPL.evaluate!!(
conditioned_model, local_varinfo
)
for k in keys(post_eval_varinfo)
@inferred post_eval_varinfo[k]
end
end
end
end
Expand Down

0 comments on commit a7317e8

Please sign in to comment.