Skip to content

Commit

Permalink
Add enumerative inference function and test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Oct 21, 2024
1 parent a5fc8e3 commit 1c46274
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 0 deletions.
131 changes: 131 additions & 0 deletions src/inference/enumerative.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""
(traces, log_norm_weights, lml_est) = enumerative_inference(
model::GenerativeFunction, model_args::Tuple,
observations::ChoiceMap, choice_vol_iter
)
Run enumerative inference over a `model`, given `observations` and an iterator over
choice maps and their associated volumes (`choice_vol_iter`), specifying the choices
to be iterated over. Return an array of traces and associated log-weights with the
same shape as `choice_vol_iter`. The log-weight of each trace is normalized, and
corresponds to the log probability of the volume of sample space that the trace
represents. Also return an estimate of the log marginal likelihood of the
observations (`lml_est`).
All addresses in the `observations` choice map must be sampled by the model when
given the model arguments. The same constraint applies to choice maps enumerated
over by `choice_vol_iter`, which must also avoid sharing addresses with the
`observations`. An iterator over a grid of choice maps and volumes can be constructed
with [`choice_vol_grid`](@ref).
"""
function enumerative_inference(
    model::GenerativeFunction{T,U}, model_args::Tuple,
    observations::ChoiceMap, choice_vol_iter::I
) where {T,U,I}
    # Iterators that advertise neither a shape nor a length are materialized
    # first so that the output buffers can be sized up front.
    size_trait = Base.IteratorSize(I)
    if !(size_trait isa Base.HasShape || size_trait isa Base.HasLength)
        choice_vol_iter = collect(choice_vol_iter)
    end
    # Shaped iterators yield outputs of the same shape; everything else is flat.
    out_dims = if size_trait isa Base.HasShape
        size(choice_vol_iter)
    else
        (length(choice_vol_iter),)
    end
    traces = Array{U}(undef, out_dims)
    log_weights = Array{Float64}(undef, out_dims)
    # Fill both buffers in iteration order using linear indices.
    for (idx, (choices, vol)) in enumerate(choice_vol_iter)
        trace, log_weight = generate(model, model_args, merge(observations, choices))
        traces[idx] = trace
        # Scale the weight returned by `generate` by the volume of the region.
        log_weights[idx] = log_weight + log(vol)
    end
    # The log of the summed weights is returned as the third output and is
    # also the normalizing constant for the per-trace weights.
    lml_est = logsumexp(log_weights)
    return (traces, log_weights .- lml_est, lml_est)
end

"""
choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
over tuples of the form `(choices::ChoiceMap, vol::Real)` via grid enumeration.
Each `addr` is an address of a random choice, and `vals` are the corresponding
values or intervals to enumerate over. The (optional) `support` denotes whether
each random choice is `:discrete` (default) or `:continuous`. This controls how
the grid is constructed:
- `support = :discrete`: The grid iterates over each value in `vals`.
- `support = :continuous` and `dims == Val(1)`: The grid iterates over the
anchors of 1D intervals whose endpoints are given by `vals`.
- `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates
over the anchors of multi-dimensional regions defined by `vals`, which is a tuple
of interval endpoints for each dimension.
Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
The `anchor` keyword argument controls which point in each interval is used as
the anchor (`:left`, `:right`, or `:midpoint`).
The volume `vol` associated with each set of `choices` in the grid is given by
the product of the volumes of each continuous region used to construct those
choices. If all addresses enumerated over are `:discrete`, then `vol = 1.0`.
"""
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
    # Expand every spec into its per-address (addr, value) pairs and volumes.
    per_addr_values = map(grid_specs) do spec
        expand_grid_spec_to_values(spec...; anchor=anchor)
    end
    per_addr_volumes = map(spec -> expand_grid_spec_to_volumes(spec...), grid_specs)
    # Cartesian products over all addresses; both grids are traversed in lockstep.
    value_grid = Iterators.product(per_addr_values...)
    volume_grid = Iterators.product(per_addr_volumes...)
    # Lazily pair each joint assignment with the product of its volumes.
    return Iterators.map(zip(value_grid, volume_grid)) do cell
        assignments, volumes = cell
        (choicemap(assignments...), prod(volumes))
    end
end

"""
    _interval_anchors(endpoints, anchor::Symbol)

Return one anchor point per interval, where consecutive elements of `endpoints`
delimit the intervals. `:left` selects all but the last endpoint, `:right` all
but the first, and any other `anchor` selects the midpoints (left endpoint plus
half the interval width).
"""
function _interval_anchors(endpoints, anchor::Symbol)
    if anchor == :left
        return @view(endpoints[begin:end-1])
    elseif anchor == :right
        return @view(endpoints[begin+1:end])
    else
        return @view(endpoints[begin:end-1]) .+ (diff(endpoints) ./ 2)
    end
end

"""
    expand_grid_spec_to_values(addr, vals, [support, dims]; anchor=:midpoint)

Expand a single grid specification into an iterator of `(addr, value)` pairs.

- `support == :discrete`: one pair per element of `vals`.
- `support == :continuous` with `dims == Val(1)`: one pair per interval, where
  `vals` holds the interval endpoints and `anchor` picks the point used.
- `support == :continuous` with `dims == Val(N)`, `N > 1`: `vals` holds `N`
  collections of endpoints; one pair per cell of their Cartesian product, with
  the value being a tuple of per-dimension anchors.

Raises an error for any other `support`, and fails an assertion when the number
of endpoint collections differs from `N`.
"""
function expand_grid_spec_to_values(
    addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1);
    anchor::Symbol = :midpoint
) where {N}
    # Each branch binds its own local so the returned generator never captures
    # a variable that was reassigned (which would force it to be boxed).
    if support == :discrete
        return ((addr, v) for v in vals)
    elseif support == :continuous && N == 1
        points = _interval_anchors(vals, anchor)
        return ((addr, v) for v in points)
    elseif support == :continuous && N > 1
        @assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
        points_per_dim = map(vs -> _interval_anchors(vs, anchor), vals)
        return ((addr, v) for v in Iterators.product(points_per_dim...))
    else
        error("Support must be :discrete or :continuous")
    end
end

function expand_grid_spec_to_volumes(
    addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
) where {N}
    # Discrete choices each get unit volume.
    support == :discrete && return ones(length(vals))
    if support == :continuous
        # 1-D: the volume of each interval is the gap between its endpoints.
        N == 1 && return diff(vals)
        if N > 1
            @assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
            # N-D: the volume of each cell is the product of its side lengths.
            side_lengths = (diff(vs) for vs in vals)
            cells = Iterators.product(side_lengths...)
            return Iterators.map(prod, cells)
        end
    end
    error("Support must be :discrete or :continuous")
end

# Exported names: the inference routine and the grid constructor defined above.
export enumerative_inference, choice_vol_grid
1 change: 1 addition & 0 deletions src/inference/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include("hmc.jl")
include("mala.jl")
include("elliptical_slice.jl")

include("enumerative.jl")
include("importance.jl")
include("particle_filter.jl")
include("map_optimize.jl")
Expand Down
80 changes: 80 additions & 0 deletions test/inference/enumerative.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
@testset "enumerative inference" begin

    # Polynomial regression model: a latent degree, one coefficient per power
    # of x, and Gaussian observation noise around the polynomial's value.
    @gen function poly_model(n::Int, xs)
        degree ~ uniform_discrete(1, n)
        coeffs = zeros(n+1)
        for d in 0:n
            coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1)
        end
        ys = zeros(length(xs))
        for (i, x) in enumerate(xs)
            x_powers = x .^ (0:n)
            y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree)
            ys[i] = {(:y, i)} ~ normal(y_mean, 0.1)
        end
        return ys
    end

    # Synthetic, noise-free dataset generated from a known quadratic.
    coeffs = [0.5, 0.1, -0.5]
    xs = collect(0.5:0.5:3.0)
    ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs]

    observations = choicemap()
    for (i, y) in enumerate(ys)
        observations[(:y, i)] = y
    end

    # Test construction of the choicemap-volume grid.
    grid = choice_vol_grid(
        (:degree, 1:2),
        ((:coeff, 0), -1:0.2:1, :continuous),
        ((:coeff, 1), -1:0.2:1, :continuous),
        ((:coeff, 2), -1:0.2:1, :continuous),
        anchor = :midpoint
    )

    @test size(grid) == (2, 10, 10, 10)
    @test length(grid) == 2000

    choices, vol = first(grid)
    @test choices == choicemap(
        (:degree, 1),
        ((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
    )
    # Volumes are products of floating-point interval widths, so compare
    # approximately rather than exactly.
    @test vol ≈ 0.2 * 0.2 * 0.2

    # Check the degree and every coefficient (indices 0 through n) for range.
    test_choices(n::Int, cs) =
        cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 0:n)

    @test all(test_choices(2, choices) for (choices, vol) in grid)
    @test all(vol ≈ (0.2 * 0.2 * 0.2) for (choices, vol) in grid)

    # Run enumerative inference over the grid.
    traces, log_norm_weights, lml_est =
        enumerative_inference(poly_model, (2, xs), observations, grid)

    @test size(traces) == (2, 10, 10, 10)
    @test length(traces) == 2000
    @test all(test_choices(2, tr) for tr in traces)

    # Test that log-weights are as expected.
    log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces]
    lml_expected = logsumexp(log_joint_weights)
    @test lml_est ≈ lml_expected
    @test all(
        (jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights)
    )

    # Test that the polynomial is most likely quadratic.
    degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4))
    @test argmax(vec(degree_probs)) == 2

    # Test that the MAP trace recovers the original coefficients. The grid
    # midpoints are computed in floating point, so compare approximately.
    map_trace_idx = argmax(log_norm_weights)
    map_trace = traces[map_trace_idx]
    @test map_trace[:degree] == 2
    @test map_trace[(:coeff, 0)] ≈ 0.5
    @test map_trace[(:coeff, 1)] ≈ 0.1
    @test map_trace[(:coeff, 2)] ≈ -0.5

end

0 comments on commit 1c46274

Please sign in to comment.