diff --git a/src/inference/enumerative.jl b/src/inference/enumerative.jl
new file mode 100644
index 00000000..32d338d5
--- /dev/null
+++ b/src/inference/enumerative.jl
@@ -0,0 +1,153 @@
+"""
+    (traces, log_norm_weights, lml_est) = enumerative_inference(
+        model::GenerativeFunction, model_args::Tuple,
+        observations::ChoiceMap, choice_vol_iter
+    )
+
+Run enumerative inference over a `model`, given `observations` and an iterator over
+choice maps and their associated volumes (`choice_vol_iter`), specifying the choices
+to be iterated over. Return an array of traces and associated log-weights with the
+same shape as `choice_vol_iter`. The log-weight of each trace is normalized, and
+corresponds to the log probability of the volume of sample space that the trace
+represents. Also return an estimate of the log marginal likelihood of the
+observations (`lml_est`).
+
+All addresses in the `observations` choice map must be sampled by the model when
+given the model arguments. The same constraint applies to choice maps enumerated
+over by `choice_vol_iter`, which must also avoid sharing addresses with the
+`observations`. An iterator over a grid of choice maps and volumes can be constructed
+with [`choice_vol_grid`](@ref).
+
+Iterators of unknown size are collected into a vector before enumeration, and
+infinite iterators are rejected with an `ArgumentError`.
+"""
+function enumerative_inference(
+    model::GenerativeFunction{T,U}, model_args::Tuple,
+    observations::ChoiceMap, choice_vol_iter::I
+) where {T,U,I}
+    # Decide the output shape from the iterator's size trait, so that shaped
+    # iterators (e.g. grids) produce arrays of traces with the same shape.
+    iter_size = Base.IteratorSize(I)
+    if iter_size isa Base.IsInfinite
+        throw(ArgumentError("`choice_vol_iter` must be a finite iterator"))
+    elseif iter_size isa Base.HasShape
+        iter = choice_vol_iter
+        dims = size(iter)
+    elseif iter_size isa Base.HasLength
+        iter = choice_vol_iter
+        dims = (length(iter),)
+    else
+        # Size is unknown: materialize the iterator so it can be measured.
+        iter = collect(choice_vol_iter)
+        dims = size(iter)
+    end
+    traces = Array{U}(undef, dims)
+    log_weights = Array{Float64}(undef, dims)
+    for (i, (choices, vol)) in enumerate(iter)
+        # Constrain the model to both the observations and the enumerated choices.
+        constraints = merge(observations, choices)
+        (traces[i], log_weight) = generate(model, model_args, constraints)
+        # Scale the density by the volume of the region the trace represents.
+        log_weights[i] = log_weight + log(vol)
+    end
+    log_total_weight = logsumexp(log_weights)
+    log_normalized_weights = log_weights .- log_total_weight
+    return (traces, log_normalized_weights, log_total_weight)
+end
+
+"""
+    choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
+
+Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
+over tuples of the form `(choices::ChoiceMap, vol::Real)` via grid enumeration.
+
+Each `addr` is an address of a random choice, and `vals` are the corresponding
+values or intervals to enumerate over. The (optional) `support` denotes whether
+each random choice is `:discrete` (default) or `:continuous`. This controls how
+the grid is constructed:
+- `support = :discrete`: The grid iterates over each value in `vals`.
+- `support = :continuous` and `dims == Val(1)`: The grid iterates over the
+  anchors of 1D intervals whose endpoints are given by `vals`.
+- `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates
+  over the anchors of multi-dimensional regions defined by `vals`, which is a tuple
+  of interval endpoints for each dimension.
+
+Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
+The `anchor` keyword argument controls which point in each interval is used as
+the anchor (`:left`, `:right`, or `:midpoint`). Any other value raises an
+`ArgumentError` when a `:continuous` choice is expanded.
+
+The volume `vol` associated with each set of `choices` in the grid is given by
+the product of the volumes of each continuous region used to construct those
+choices. If all addresses enumerated over are `:discrete`, then `vol = 1.0`.
+"""
+function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
+    # Expand each spec into per-address (addr, value) pairs and matching volumes.
+    val_iters = map(grid_specs) do spec
+        return expand_grid_spec_to_values(spec...; anchor=anchor)
+    end
+    vol_iters = map(grid_specs) do spec
+        return expand_grid_spec_to_volumes(spec...)
+    end
+    # Take the Cartesian product across addresses, pairing values with volumes.
+    grid = zip(Iterators.product(val_iters...), Iterators.product(vol_iters...))
+    return Iterators.map(grid) do (vals, vols)
+        return (choicemap(vals...), prod(vols))
+    end
+end
+
+# Return the anchor point of each interval delimited by consecutive `endpoints`.
+function grid_interval_anchors(endpoints, anchor::Symbol)
+    if anchor == :left
+        return @view(endpoints[begin:end-1])
+    elseif anchor == :right
+        return @view(endpoints[begin+1:end])
+    elseif anchor == :midpoint
+        return @view(endpoints[begin:end-1]) .+ (diff(endpoints) ./ 2)
+    else
+        throw(ArgumentError("Anchor must be :left, :right, or :midpoint"))
+    end
+end
+
+# Expand a grid specification into an iterator over `(addr, value)` pairs.
+function expand_grid_spec_to_values(
+    addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1);
+    anchor::Symbol = :midpoint
+) where {N}
+    if support == :discrete
+        return ((addr, v) for v in vals)
+    elseif support == :continuous && N == 1
+        anchors = grid_interval_anchors(vals, anchor)
+        return ((addr, v) for v in anchors)
+    elseif support == :continuous && N > 1
+        length(vals) == N ||
+            throw(DimensionMismatch("Dimension mismatch between `vals` and `dims`"))
+        # One set of anchors per dimension; each value is a tuple of coordinates.
+        dim_anchors = map(vs -> grid_interval_anchors(vs, anchor), vals)
+        return ((addr, v) for v in Iterators.product(dim_anchors...))
+    else
+        error("Support must be :discrete or :continuous")
+    end
+end
+
+# Expand a grid specification into the volume of each enumerated region.
+function expand_grid_spec_to_volumes(
+    addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
+) where {N}
+    if support == :discrete
+        # Discrete values carry unit volume.
+        return ones(length(vals))
+    elseif support == :continuous && N == 1
+        return diff(vals)
+    elseif support == :continuous && N > 1
+        length(vals) == N ||
+            throw(DimensionMismatch("Dimension mismatch between `vals` and `dims`"))
+        # The volume of a region is the product of its side lengths.
+        diffs = Iterators.product((diff(vs) for vs in vals)...)
+        return (prod(ds) for ds in diffs)
+    else
+        error("Support must be :discrete or :continuous")
+    end
+end
+
+export enumerative_inference, choice_vol_grid
diff --git a/src/inference/inference.jl b/src/inference/inference.jl
index d37298e2..1792b38c 100644
--- a/src/inference/inference.jl
+++ b/src/inference/inference.jl
@@ -21,6 +21,7 @@
 include("hmc.jl")
 include("mala.jl")
 include("elliptical_slice.jl")
+include("enumerative.jl")
 include("importance.jl")
 include("particle_filter.jl")
 include("map_optimize.jl")
diff --git a/test/inference/enumerative.jl b/test/inference/enumerative.jl
new file mode 100644
index 00000000..ff27c6ea
--- /dev/null
+++ b/test/inference/enumerative.jl
@@ -0,0 +1,92 @@
+@testset "enumerative inference" begin
+
+    # polynomial regression model
+    @gen function poly_model(n::Int, xs)
+        degree ~ uniform_discrete(1, n)
+        coeffs = zeros(n+1)
+        for d in 0:n
+            coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1)
+        end
+        ys = zeros(length(xs))
+        for (i, x) in enumerate(xs)
+            x_powers = x .^ (0:n)
+            y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree)
+            ys[i] = {(:y, i)} ~ normal(y_mean, 0.1)
+        end
+        return ys
+    end
+
+    # synthetic dataset
+    coeffs = [0.5, 0.1, -0.5]
+    xs = collect(0.5:0.5:3.0)
+    ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs]
+
+    observations = choicemap()
+    for (i, y) in enumerate(ys)
+        observations[(:y, i)] = y
+    end
+
+    # test construction of choicemap-volume grid
+    grid = choice_vol_grid(
+        (:degree, 1:2),
+        ((:coeff, 0), -1:0.2:1, :continuous),
+        ((:coeff, 1), -1:0.2:1, :continuous),
+        ((:coeff, 2), -1:0.2:1, :continuous),
+        anchor = :midpoint
+    )
+
+    @test size(grid) == (2, 10, 10, 10)
+    @test length(grid) == 2000
+
+    choices, vol = first(grid)
+    @test choices == choicemap(
+        (:degree, 1),
+        ((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
+    )
+    @test vol ≈ 0.2 * 0.2 * 0.2
+
+    # coefficients are addressed (:coeff, 0) through (:coeff, n), so check all of them
+    test_choices(n::Int, cs) =
+        cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 0:n)
+
+    @test all(test_choices(2, choices) for (choices, vol) in grid)
+    @test all(vol ≈ (0.2 * 0.2 * 0.2) for (choices, vol) in grid)
+
+    # test alternative anchors, and that unknown anchors are rejected
+    left_grid = choice_vol_grid((:x, 0:0.5:1, :continuous); anchor = :left)
+    @test [cs[:x] for (cs, _) in left_grid] ≈ [0.0, 0.5]
+    right_grid = choice_vol_grid((:x, 0:0.5:1, :continuous); anchor = :right)
+    @test [cs[:x] for (cs, _) in right_grid] ≈ [0.5, 1.0]
+    @test_throws ArgumentError choice_vol_grid(
+        (:x, 0:0.5:1, :continuous); anchor = :center
+    )
+
+    # run enumerative inference over grid
+    traces, log_norm_weights, lml_est =
+        enumerative_inference(poly_model, (2, xs), observations, grid)
+
+    @test size(traces) == (2, 10, 10, 10)
+    @test length(traces) == 2000
+    @test all(test_choices(2, tr) for tr in traces)
+
+    # test that log-weights are as expected
+    log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces]
+    lml_expected = logsumexp(log_joint_weights)
+    @test lml_est ≈ lml_expected
+    @test all(
+        (jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights)
+    )
+
+    # test that polynomial is most likely quadratic
+    degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4))
+    @test argmax(vec(degree_probs)) == 2
+
+    # test that MAP trace recovers the original coefficients
+    map_trace_idx = argmax(log_norm_weights)
+    map_trace = traces[map_trace_idx]
+    @test map_trace[:degree] == 2
+    @test map_trace[(:coeff, 0)] ≈ 0.5
+    @test map_trace[(:coeff, 1)] ≈ 0.1
+    @test map_trace[(:coeff, 2)] ≈ -0.5
+
+end