Skip to content

Commit

Permalink
Add multi-variate enumeration test case.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Oct 21, 2024
1 parent 7104004 commit 1146ca0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/inference/enumerative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ function expand_grid_spec_to_values(
end
return vs
end
return ((addr, v) for v in Iterators.product(vals...))
return ((addr, collect(v)) for v in Iterators.product(vals...))
else
error("Support must be :discrete or :continuous")
end
Expand Down
56 changes: 56 additions & 0 deletions test/inference/enumerative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,60 @@
@test map_trace[(:coeff, 1)] == 0.1
@test map_trace[(:coeff, 2)] == -0.5

# 2D mixture of normals
@gen function mixture_model()
sign ~ bernoulli(0.5)
mu = sign ? fill(0.5, 2) : fill(-0.5, 2)
z ~ broadcasted_normal(mu, ones(2))
end

# test construction of grid with 2D random variable
grid = choice_vol_grid(
(:sign, [false, true]),
(:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)),
anchor = :left
)

@test size(grid) == (2, 40, 40)
@test length(grid) == 3200

choices, log_vol = first(grid)
@test choices == choicemap((:sign, false), (:z, [-2.0, -2.0]))
@test log_vol log(0.1^2)

@test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid)
@test all(log_vol log(0.1^2) for (_, log_vol) in grid)

# run enumerative inference over grid
traces, log_norm_weights, lml_est =
enumerative_inference(mixture_model, (), choicemap(), grid)

@test size(traces) == (2, 40, 40)
@test length(traces) == 3200
@test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces)

# test that log-weights are as expected
function expected_logpdf(tr)
x, y = tr[:z]
mu = tr[:sign] ? 0.5 : -0.5
return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0)
end

log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces]
lml_expected = logsumexp(log_joint_weights)
@test lml_est lml_expected
@test all((jw - lml_expected) w for (jw, w) in zip(log_joint_weights, log_norm_weights))

# test that maximal log-weights are at modes
max_log_weight = maximum(log_norm_weights)
max_idxs = findall(log_norm_weights .== max_log_weight)

max_trace_1 = traces[max_idxs[1]]
@test max_trace_1[:sign] == false
@test max_trace_1[:z] == [-0.5, -0.5]

max_trace_2 = traces[max_idxs[2]]
@test max_trace_2[:sign] == true
@test max_trace_2[:z] == [0.5, 0.5]

end

0 comments on commit 1146ca0

Please sign in to comment.