diff --git a/Project.toml b/Project.toml
index 95342249c..97969944d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,7 +46,7 @@ DynamicPPLZygoteRulesExt = ["ZygoteRules"]
 [compat]
 ADTypes = "1"
 AbstractMCMC = "5"
-AbstractPPL = "0.8.4, 0.9"
+AbstractPPL = "0.10.1"
 Accessors = "0.1"
 BangBang = "0.4.1"
 Bijectors = "0.13.18, 0.14, 0.15"
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index 41efcb15c..06cde3bac 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -48,64 +48,59 @@ end
 Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
 in `chain`, and return the resulting `Chains`.
 
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then uses the previously inferred parameter values to
+simulate what new, unobserved data might look like, given your posterior beliefs.
+
+For each parameter configuration in `chain`:
+1. All random variables present in `chain` are fixed to their sampled values.
+2. Any variables not included in `chain` are sampled from their prior distributions.
+
 If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
 the samples in `chain`. This is useful when you want to sample only new variables from the posterior
 predictive distribution.
 
 # Examples
 ```jldoctest
-julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff;
-
-julia> @model function linear_reg(x, y, σ = 0.1)
-           β ~ Normal(0, 1)
-           for i ∈ eachindex(y)
-               y[i] ~ Normal(β * x[i], σ)
-           end
-       end;
-
-julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();
-
-julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);
-
-julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);
-
-julia> m_train = linear_reg(xs_train, ys_train, σ);
+using AbstractMCMC, Distributions, DynamicPPL, MCMCChains, Random
 
-julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));
+@model function linear_reg(x, y, σ = 0.1)
+    β ~ Normal(0, 1)
+    for i in eachindex(y)
+        y[i] ~ Normal(β * x[i], σ)
+    end
+end
 
-julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β], discard_initial=100)
-┌ Info: Found initial step size
-└   ϵ = 0.003125
+# Generate a synthetic chain using a known ground-truth parameter
+ground_truth_β = 2.0
 
-julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);
+# Create a chain of samples from a normal distribution centered on the ground truth
+β_chain = MCMCChains.Chains(
+    rand(Normal(ground_truth_β, 0.002), 1000), [:β]
+)
 
-julia> predictions = predict(m_test, chain_lin_reg)
-Object of type Chains, with data of type 100×2×1 Array{Float64,3}
+# Generate predictions for two test points
+xs_test = [10.1, 10.2]
 
-Iterations = 1:100
-Thinning interval = 1
-Chains = 1
-Samples per chain = 100
-parameters = y[1], y[2]
+m_test = linear_reg(xs_test, fill(missing, length(xs_test)))
 
-2-element Array{ChainDataFrame,1}
+predictions = DynamicPPL.predict(
+    Random.default_rng(), m_test, β_chain
+)
 
-Summary Statistics
-  parameters     mean     std  naive_se    mcse       ess   r_hat
-  ──────────  ───────  ──────  ────────  ───────  ────────  ──────
-        y[1]  20.1974  0.1007    0.0101  missing  101.0711  0.9922
-        y[2]  20.3867  0.1062    0.0106  missing  101.4889  0.9903
+ys_pred = vec(mean(Array(predictions); dims=1))
 
-Quantiles
-  parameters     2.5%    25.0%    50.0%    75.0%    97.5%
-  ──────────  ───────  ───────  ───────  ───────  ───────
-        y[1]  20.0342  20.1188  20.2135  20.2588  20.4188
-        y[2]  20.1870  20.3178  20.3839  20.4466  20.5895
+# Check if predictions match expected values within tolerance
+(
+    isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
+    isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
+)
 
-julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));
+# output
 
-julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
-true
+(true, true)
 ```
 """
 function DynamicPPL.predict(
@@ -115,14 +110,11 @@ function DynamicPPL.predict(
     include_all=false,
 )
     parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
-    prototypical_varinfo = DynamicPPL.VarInfo(model)
+    varinfo = DynamicPPL.VarInfo(model)
     iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
     predictive_samples = map(iters) do (sample_idx, chain_idx)
-        varinfo = deepcopy(prototypical_varinfo)
-        DynamicPPL.setval_and_resample!(
-            varinfo, parameter_only_chain, sample_idx, chain_idx
-        )
+        DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
         model(rng, varinfo, DynamicPPL.SampleFromPrior())
 
         vals = DynamicPPL.values_as_in_model(model, varinfo)
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index c314e6c6d..c1cdbd94e 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -40,6 +40,8 @@ import Base:
     keys,
     haskey
 
+import AbstractPPL: predict
+
 # VarInfo
 export AbstractVarInfo,
     VarInfo,
diff --git a/src/model.jl b/src/model.jl
index 037ed8379..2bad6f1fe 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1145,19 +1145,23 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
 end
 
 """
-    predict([rng::AbstractRNG,] model::Model, chain; include_all=false)
+    predict([rng::AbstractRNG,] model::Model, chain::AbstractArray{<:AbstractVarInfo})
 
-Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
-in `chain`.
-
-If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior
-predictive distribution.
+Generate samples from the posterior predictive distribution by evaluating `model` at each set
+of parameter values provided in `chain`. The number of posterior predictive samples matches
+the length of `chain`. The returned `AbstractVarInfo`s contain both the posterior parameter
+values and the predicted values.
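+
+# Examples
+
+A minimal sketch of how this method might be used; the toy model, data, and sample
+counts below are purely illustrative:
+
+```julia
+using DynamicPPL, Distributions, Random
+
+@model function toy_reg(x, y)
+    β ~ Normal(0, 1)
+    for i in eachindex(y)
+        y[i] ~ Normal(β * x[i], 0.1)
+    end
+end
+
+# Collect parameter configurations as `VarInfo`s; evaluating the model draws `β`
+# from its prior, so this stands in for a "chain" of posterior samples.
+m = toy_reg([1.0, 2.0, 3.0], [2.1, 3.9, 6.2])
+chain = [evaluate!!(m)[2] for _ in 1:100]
+
+# Mark the responses as `missing` so they are predicted rather than conditioned on.
+m_pred = toy_reg([4.0, 5.0], fill(missing, 2))
+predictions = DynamicPPL.predict(Random.default_rng(), m_pred, chain)
+
+predictions[1][@varname(y[1])]  # predicted `y[1]` under the first parameter draw
+```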
""" -function predict(model::Model, chain; include_all=false) - # this is only defined in `ext/DynamicPPLMCMCChainsExt.jl` - # TODO: add other methods for different type of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict` - return predict(Random.default_rng(), model, chain; include_all) +function predict( + rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo} +) + varinfo = DynamicPPL.VarInfo(model) + return map(chain) do params_varinfo + vi = deepcopy(varinfo) + DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) + model(rng, vi, SampleFromPrior()) + return vi + end end """ diff --git a/test/Project.toml b/test/Project.toml index 11ebeaad8..c7583c672 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,7 +2,6 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" -AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" @@ -34,9 +33,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.8.4, 0.9" +AbstractPPL = "0.10.1" Accessors = "0.1" -AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6" Bijectors = "0.15.1" Combinatorics = "1" Compat = "4.3.0" diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 8693c3b02..3ba5edfe1 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -8,169 +8,4 @@ @test mean(chain_generated) ≈ 0 atol = 0.1 end -@testset "predict" begin - DynamicPPL.Random.seed!(100) - - @model function linear_reg(x, y, σ=0.1) - β ~ Normal(0, 1) - - for i in eachindex(y) - y[i] ~ Normal(β * x[i], σ) - end - end - - @model function linear_reg_vec(x, y, σ=0.1) - β ~ Normal(0, 1) - return y ~ MvNormal(β .* x, σ^2 * I) - end - - f(x) = 2 * x + 0.1 * randn() - - Δ = 0.1 - xs_train = 0:Δ:10 - ys_train = f.(xs_train) - xs_test = [10 + Δ, 10 + 2 * Δ] - ys_test = f.(xs_test) - - # Infer - m_lin_reg = linear_reg(xs_train, ys_train) - chain_lin_reg = sample( - DynamicPPL.LogDensityFunction(m_lin_reg), - AdvancedHMC.NUTS(0.65), - 1000; - chain_type=MCMCChains.Chains, - param_names=[:β], - discard_initial=100, - n_adapt=100, - ) - - # Predict on two last indices - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test))) - predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg) - - ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) - - # test like this depends on the variance of the posterior - # this only makes sense if the posterior variance is about 0.002 - @test sum(abs2, ys_test - ys_pred) ≤ 0.1 - - # Ensure that `rng` is respected - predictions1 = let rng = MersenneTwister(42) - DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2]) - end - predictions2 = let rng = MersenneTwister(42) - DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2]) - end - @test all(Array(predictions1) .== Array(predictions2)) - - # Predict on two last indices for vectorized - m_lin_reg_test = linear_reg_vec(xs_test, missing) - predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg) - ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1)) - - @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1 - - # Multiple chains - chain_lin_reg = sample( - DynamicPPL.LogDensityFunction(m_lin_reg, 
DynamicPPL.VarInfo(m_lin_reg)), - AdvancedHMC.NUTS(0.65), - MCMCThreads(), - 1000, - 2; - chain_type=MCMCChains.Chains, - param_names=[:β], - discard_initial=100, - n_adapt=100, - ) - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test))) - predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg) - - @test size(chain_lin_reg, 3) == size(predictions, 3) - - for chain_idx in MCMCChains.chains(chain_lin_reg) - ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)) - @test sum(abs2, ys_test - ys_pred) ≤ 0.1 - end - - # Predict on two last indices for vectorized - m_lin_reg_test = linear_reg_vec(xs_test, missing) - predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg) - - for chain_idx in MCMCChains.chains(chain_lin_reg) - ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)) - @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1 - end - - # https://github.com/TuringLang/Turing.jl/issues/1352 - @model function simple_linear1(x, y) - intercept ~ Normal(0, 1) - coef ~ MvNormal(zeros(2), I) - coef = reshape(coef, 1, size(x, 1)) - - mu = vec(intercept .+ coef * x) - error ~ truncated(Normal(0, 1), 0, Inf) - return y ~ MvNormal(mu, error^2 * I) - end - - @model function simple_linear2(x, y) - intercept ~ Normal(0, 1) - coef ~ filldist(Normal(0, 1), 2) - coef = reshape(coef, 1, size(x, 1)) - - mu = vec(intercept .+ coef * x) - error ~ truncated(Normal(0, 1), 0, Inf) - return y ~ MvNormal(mu, error^2 * I) - end - - @model function simple_linear3(x, y) - intercept ~ Normal(0, 1) - coef = Vector(undef, 2) - for i in axes(coef, 1) - coef[i] ~ Normal(0, 1) - end - coef = reshape(coef, 1, size(x, 1)) - - mu = vec(intercept .+ coef * x) - error ~ truncated(Normal(0, 1), 0, Inf) - return y ~ MvNormal(mu, error^2 * I) - end - - @model function simple_linear4(x, y) - intercept ~ Normal(0, 1) - coef1 ~ Normal(0, 1) - coef2 ~ Normal(0, 1) - coef = [coef1, coef2] - coef = reshape(coef, 1, size(x, 1)) - - mu = vec(intercept .+ coef * x) - error ~ truncated(Normal(0, 1), 0, Inf) - return y ~ MvNormal(mu, error^2 * I) - end - - x = randn(2, 100) - y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)] - - param_names = Dict( - simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error], - simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error], - simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error], - simple_linear4 => [:intercept, :coef1, :coef2, :error], - ) - @testset "$model" for model in - [simple_linear1, simple_linear2, simple_linear3, simple_linear4] - m = model(x, y) - chain = sample( - DynamicPPL.LogDensityFunction(m), - AdvancedHMC.NUTS(0.65), - 400; - initial_params=rand(4), - chain_type=MCMCChains.Chains, - param_names=param_names[model], - discard_initial=100, - n_adapt=100, - ) - chain_predict = DynamicPPL.predict(model(x, missing), chain) - mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)] - @test mean(abs2, mean_prediction - y) ≤ 1e-3 - end -end +# test for `predict` is in `test/model.jl` diff --git a/test/model.jl b/test/model.jl index a19cb29d2..cb1dbc735 100644 --- a/test/model.jl +++ b/test/model.jl @@ -429,4 +429,109 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result) end end + + @testset "predict" begin + @testset "with MCMCChains.Chains" begin + DynamicPPL.Random.seed!(100) + + @model function linear_reg(x, y, σ=0.1) + β ~ Normal(0, 1) + for i in 
eachindex(y) + y[i] ~ Normal(β * x[i], σ) + end + end + + @model function linear_reg_vec(x, y, σ=0.1) + β ~ Normal(0, 1) + return y ~ MvNormal(β .* x, σ^2 * I) + end + + ground_truth_β = 2 + β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β]) + + xs_test = [10 + 0.1, 10 + 2 * 0.1] + m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) + predictions = DynamicPPL.predict(m_lin_reg_test, β_chain) + + ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) + @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 + @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 + + # Ensure that `rng` is respected + rng = MersenneTwister(42) + predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2]) + predictions2 = DynamicPPL.predict( + MersenneTwister(42), m_lin_reg_test, β_chain[1:2] + ) + @test all(Array(predictions1) .== Array(predictions2)) + + # Predict on two last indices for vectorized + m_lin_reg_test = linear_reg_vec(xs_test, missing) + predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain) + ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1)) + + @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 + @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 + + # Multiple chains + multiple_β_chain = MCMCChains.Chains( + reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β] + ) + m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) + predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) + @test size(multiple_β_chain, 3) == size(predictions, 3) + + for chain_idx in MCMCChains.chains(multiple_β_chain) + ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)) + @test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 + @test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 + end + + # Predict on two last indices for vectorized + m_lin_reg_test = linear_reg_vec(xs_test, missing) + predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain) + + for chain_idx in MCMCChains.chains(multiple_β_chain) + ys_pred_vec = vec( + mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1) + ) + @test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01 + @test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01 + end + end + + @testset "with AbstractVector{<:AbstractVarInfo}" begin + @model function linear_reg(x, y, σ=0.1) + β ~ Normal(1, 1) + for i in eachindex(y) + y[i] ~ Normal(β * x[i], σ) + end + end + + ground_truth_β = 2.0 + # the data will be ignored, as we are generating samples from the prior + xs_train = 1:0.1:10 + ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) + m_lin_reg = linear_reg(xs_train, ys_train) + chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000] + + # chain is generated from the prior + @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 + + xs_test = [10 + 0.1, 10 + 2 * 0.1] + m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test))) + predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain) + + @test size(predicted_vis) == size(chain) + @test Set(keys(predicted_vis[1])) == + Set([@varname(β), @varname(y[1]), @varname(y[2])]) + # because β samples are from the prior, the std will be larger + @test mean([ + predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis) + ]) ≈ 1.0 * xs_test[1] rtol = 0.1 + @test mean([ + predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis) + ]) ≈ 1.0 * xs_test[2] rtol = 
0.1 + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 6fd925cae..9f2d21990 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,5 @@ using Accessors using ADTypes -using AdvancedHMC: AdvancedHMC using DynamicPPL using AbstractMCMC using AbstractPPL