Skip to content

Commit

Permalink
Switch to log-volumes for numeric stability.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Oct 21, 2024
1 parent 1c46274 commit 7104004
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
39 changes: 20 additions & 19 deletions src/inference/enumerative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,19 @@
)
Run enumerative inference over a `model`, given `observations` and an iterator over
choice maps and their associated volumes (`choice_vol_iter`), specifying the choices
to be iterated over. Return an array of traces and associated log-weights with the
same shape as `choice_vol_iter`. The log-weight of each trace is normalized, and
corresponds to the log probablity of the volume of sample space that the trace
represents. Also return an estimate of the log marginal likelihood of the
observations (`lml_est`).
choice maps and their associated log-volumes (`choice_vol_iter`), specifying the
choices to be iterated over. An iterator over a grid of choice maps and log-volumes
can be constructed with [`choice_vol_grid`](@ref).
Return an array of traces and associated log-weights with the same shape as
`choice_vol_iter`. The log-weight of each trace is normalized, and corresponds
to the log probability of the volume of sample space that the trace represents.
Also return an estimate of the log marginal likelihood of the observations (`lml_est`).
All addresses in the `observations` choice map must be sampled by the model when
given the model arguments. The same constraint applies to choice maps enumerated
over by `choice_vol_iter`, which must also avoid sharing addresses with the
`observations`. An iterator over a grid of choice maps and volumes can be constructed
with [`choice_vol_grid`](@ref).
`observations`.
"""
function enumerative_inference(
model::GenerativeFunction{T,U}, model_args::Tuple,
Expand All @@ -33,10 +34,10 @@ function enumerative_inference(
traces = Vector{U}(undef, length(choice_vol_iter))
log_weights = Vector{Float64}(undef, length(choice_vol_iter))
end
for (i, (choices, vol)) in enumerate(choice_vol_iter)
for (i, (choices, log_vol)) in enumerate(choice_vol_iter)
constraints = merge(observations, choices)
(traces[i], log_weight) = generate(model, model_args, constraints)
log_weights[i] = log_weight + log(vol)
log_weights[i] = log_weight + log_vol
end
log_total_weight = logsumexp(log_weights)
log_normalized_weights = log_weights .- log_total_weight
Expand All @@ -47,7 +48,7 @@ end
choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
over tuples of the form `(choices::ChoiceMap, vol::Real)` via grid enumeration.
over tuples of the form `(choices::ChoiceMap, log_vol::Real)` via grid enumeration.
Each `addr` is an address of a random choice, and `vals` are the corresponding
values or intervals to enumerate over. The (optional) `support` denotes whether
Expand All @@ -63,9 +64,9 @@ Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
The `anchor` keyword argument controls which point in each interval is used as
the anchor (`:left`, `:right`, or `:midpoint`).
The volume `vol` associated with each set of `choices` in the grid is given by
the product of the volumes of each continuous region used to construct those
choices. If all addresses enumerated over are `:discrete`, then `vol = 1.0`.
The log-volume `log_vol` associated with each set of `choices` in the grid is given
by the log-product of the volumes of each continuous region used to construct those
choices. If all addresses enumerated over are `:discrete`, then `log_vol = 0.0`.
"""
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
val_iter = (expand_grid_spec_to_values(spec...; anchor=anchor)
Expand All @@ -74,7 +75,7 @@ function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
vol_iter = (expand_grid_spec_to_volumes(spec...) for spec in grid_specs)
vol_iter = Iterators.product(vol_iter...)
choice_vol_iter = Iterators.map(zip(val_iter, vol_iter)) do (vals, vols)
return (choicemap(vals...), prod(vols))
return (choicemap(vals...), sum(vols))
end
return choice_vol_iter
end
Expand Down Expand Up @@ -116,13 +117,13 @@ function expand_grid_spec_to_volumes(
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
) where {N}
if support == :discrete
return ones(length(vals))
return zeros(length(vals))
elseif support == :continuous && N == 1
return diff(vals)
return log.(diff(vals))
elseif support == :continuous && N > 1
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
diffs = Iterators.product((diff(vs) for vs in vals)...)
return (prod(ds) for ds in diffs)
diffs = Iterators.product((log.(diff(vs)) for vs in vals)...)
return (sum(ds) for ds in diffs)
else
error("Support must be :discrete or :continuous")
end
Expand Down
8 changes: 4 additions & 4 deletions test/inference/enumerative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@
@test size(grid) == (2, 10, 10, 10)
@test length(grid) == 2000

choices, vol = first(grid)
choices, log_vol = first(grid)
@test choices == choicemap(
(:degree, 1),
((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
)
@test vol 0.2 * 0.2 * 0.2
@test log_vol log(0.2^3)

test_choices(n::Int, cs) =
cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 1:n)

@test all(test_choices(2, choices) for (choices, vol) in grid)
@test all(vol (0.2 * 0.2 * 0.2) for (choices, vol) in grid)
@test all(test_choices(2, choices) for (choices, _) in grid)
@test all(log_vol log(0.2^3) for (_, log_vol) in grid)

# run enumerative inference over grid
traces, log_norm_weights, lml_est =
Expand Down

0 comments on commit 7104004

Please sign in to comment.