-
Notifications
You must be signed in to change notification settings - Fork 162
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add enumerative inference function and test cases.
- Loading branch information
Showing
3 changed files
with
212 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
""" | ||
(traces, log_norm_weights, lml_est) = enumerative_inference( | ||
model::GenerativeFunction, model_args::Tuple, | ||
observations::ChoiceMap, choice_vol_iter | ||
) | ||
Run enumerative inference over a `model`, given `observations` and an iterator over | ||
choice maps and their associated volumes (`choice_vol_iter`), specifying the choices | ||
to be iterated over. Return an array of traces and associated log-weights with the | ||
same shape as `choice_vol_iter`. The log-weight of each trace is normalized, and | ||
corresponds to the log probablity of the volume of sample space that the trace | ||
represents. Also return an estimate of the log marginal likelihood of the | ||
observations (`lml_est`). | ||
All addresses in the `observations` choice map must be sampled by the model when | ||
given the model arguments. The same constraint applies to choice maps enumerated | ||
over by `choice_vol_iter`, which must also avoid sharing addresses with the | ||
`observations`. An iterator over a grid of choice maps and volumes can be constructed | ||
with [`choice_vol_grid`](@ref). | ||
""" | ||
function enumerative_inference( | ||
model::GenerativeFunction{T,U}, model_args::Tuple, | ||
observations::ChoiceMap, choice_vol_iter::I | ||
) where {T,U,I} | ||
if Base.IteratorSize(I) isa Base.HasShape | ||
traces = Array{U}(undef, size(choice_vol_iter)) | ||
log_weights = Array{Float64}(undef, size(choice_vol_iter)) | ||
elseif Base.IteratorSize(I) isa Base.HasLength | ||
traces = Vector{U}(undef, length(choice_vol_iter)) | ||
log_weights = Vector{Float64}(undef, length(choice_vol_iter)) | ||
else | ||
choice_vol_iter = collect(choice_vol_iter) | ||
traces = Vector{U}(undef, length(choice_vol_iter)) | ||
log_weights = Vector{Float64}(undef, length(choice_vol_iter)) | ||
end | ||
for (i, (choices, vol)) in enumerate(choice_vol_iter) | ||
constraints = merge(observations, choices) | ||
(traces[i], log_weight) = generate(model, model_args, constraints) | ||
log_weights[i] = log_weight + log(vol) | ||
end | ||
log_total_weight = logsumexp(log_weights) | ||
log_normalized_weights = log_weights .- log_total_weight | ||
return (traces, log_normalized_weights, log_total_weight) | ||
end | ||
|
||
""" | ||
choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint) | ||
Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator | ||
over tuples of the form `(choices::ChoiceMap, vol::Real)` via grid enumeration. | ||
Each `addr` is an address of a random choice, and `vals` are the corresponding | ||
values or intervals to enumerate over. The (optional) `support` denotes whether | ||
each random choice is `:discrete` (default) or `:continuous`. This controls how | ||
the grid is constructed: | ||
- `support = :discrete`: The grid iterates over each value in `vals`. | ||
- `support = :continuous` and `dims == Val(1)`: The grid iterates over the | ||
anchors of 1D intervals whose endpoints are given by `vals`. | ||
- `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates | ||
over the anchors of multi-dimensional regions defined `vals`, which is a tuple | ||
of interval endpoints for each dimension. | ||
Continuous choices are assumed to have `dims = Val(1)` dimensions by default. | ||
The `anchor` keyword argument controls which point in each interval is used as | ||
the anchor (`:left`, `:right`, or `:midpoint`). | ||
The volume `vol` associated with each set of `choices` in the grid is given by | ||
the product of the volumes of each continuous region used to construct those | ||
choices. If all addresses enumerated over are `:discrete`, then `vol = 1.0`. | ||
""" | ||
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint) | ||
val_iter = (expand_grid_spec_to_values(spec...; anchor=anchor) | ||
for spec in grid_specs) | ||
val_iter = Iterators.product(val_iter...) | ||
vol_iter = (expand_grid_spec_to_volumes(spec...) for spec in grid_specs) | ||
vol_iter = Iterators.product(vol_iter...) | ||
choice_vol_iter = Iterators.map(zip(val_iter, vol_iter)) do (vals, vols) | ||
return (choicemap(vals...), prod(vols)) | ||
end | ||
return choice_vol_iter | ||
end | ||
|
||
function expand_grid_spec_to_values( | ||
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1); | ||
anchor::Symbol = :midpoint | ||
) where {N} | ||
if support == :discrete | ||
return ((addr, v) for v in vals) | ||
elseif support == :continuous && N == 1 | ||
if anchor == :left | ||
vals = @view(vals[begin:end-1]) | ||
elseif anchor == :right | ||
vals = @view(vals[begin+1:end]) | ||
else | ||
vals = @view(vals[begin:end-1]) .+ (diff(vals) ./ 2) | ||
end | ||
return ((addr, v) for v in vals) | ||
elseif support == :continuous && N > 1 | ||
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`" | ||
vals = map(vals) do vs | ||
if anchor == :left | ||
vs = @view(vs[begin:end-1]) | ||
elseif anchor == :right | ||
vs = @view(vs[begin+1:end]) | ||
else | ||
vs = @view(vs[begin:end-1]) .+ (diff(vs) ./ 2) | ||
end | ||
return vs | ||
end | ||
return ((addr, v) for v in Iterators.product(vals...)) | ||
else | ||
error("Support must be :discrete or :continuous") | ||
end | ||
end | ||
|
||
function expand_grid_spec_to_volumes( | ||
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1) | ||
) where {N} | ||
if support == :discrete | ||
return ones(length(vals)) | ||
elseif support == :continuous && N == 1 | ||
return diff(vals) | ||
elseif support == :continuous && N > 1 | ||
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`" | ||
diffs = Iterators.product((diff(vs) for vs in vals)...) | ||
return (prod(ds) for ds in diffs) | ||
else | ||
error("Support must be :discrete or :continuous") | ||
end | ||
end | ||
|
||
export enumerative_inference, choice_vol_grid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
@testset "enumerative inference" begin | ||
|
||
# polynomial regression model | ||
@gen function poly_model(n::Int, xs) | ||
degree ~ uniform_discrete(1, n) | ||
coeffs = zeros(n+1) | ||
for d in 0:n | ||
coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1) | ||
end | ||
ys = zeros(length(xs)) | ||
for (i, x) in enumerate(xs) | ||
x_powers = x .^ (0:n) | ||
y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree) | ||
ys[i] = {(:y, i)} ~ normal(y_mean, 0.1) | ||
end | ||
return ys | ||
end | ||
|
||
# synthetic dataset | ||
coeffs = [0.5, 0.1, -0.5] | ||
xs = collect(0.5:0.5:3.0) | ||
ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs] | ||
|
||
observations = choicemap() | ||
for (i, y) in enumerate(ys) | ||
observations[(:y, i)] = y | ||
end | ||
|
||
# test construction of choicemap-volume grid | ||
grid = choice_vol_grid( | ||
(:degree, 1:2), | ||
((:coeff, 0), -1:0.2:1, :continuous), | ||
((:coeff, 1), -1:0.2:1, :continuous), | ||
((:coeff, 2), -1:0.2:1, :continuous), | ||
anchor = :midpoint | ||
) | ||
|
||
@test size(grid) == (2, 10, 10, 10) | ||
@test length(grid) == 2000 | ||
|
||
choices, vol = first(grid) | ||
@test choices == choicemap( | ||
(:degree, 1), | ||
((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9), | ||
) | ||
@test vol ≈ 0.2 * 0.2 * 0.2 | ||
|
||
test_choices(n::Int, cs) = | ||
cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 1:n) | ||
|
||
@test all(test_choices(2, choices) for (choices, vol) in grid) | ||
@test all(vol ≈ (0.2 * 0.2 * 0.2) for (choices, vol) in grid) | ||
|
||
# run enumerative inference over grid | ||
traces, log_norm_weights, lml_est = | ||
enumerative_inference(poly_model, (2, xs), observations, grid) | ||
|
||
@test size(traces) == (2, 10, 10, 10) | ||
@test length(traces) == 2000 | ||
@test all(test_choices(2, tr) for tr in traces) | ||
|
||
# test that log-weights are as expected | ||
log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces] | ||
lml_expected = logsumexp(log_joint_weights) | ||
@test lml_est ≈ lml_expected | ||
@test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights)) | ||
|
||
# test that polynomial is most likely quadratic | ||
degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4)) | ||
@test argmax(vec(degree_probs)) == 2 | ||
|
||
# test that MAP trace recovers the original coefficients | ||
map_trace_idx = argmax(log_norm_weights) | ||
map_trace = traces[map_trace_idx] | ||
@test map_trace[:degree] == 2 | ||
@test map_trace[(:coeff, 0)] == 0.5 | ||
@test map_trace[(:coeff, 1)] == 0.1 | ||
@test map_trace[(:coeff, 2)] == -0.5 | ||
|
||
end |