Commit

update implementation and tests; no longer using AdvancedHMC
sunxd3 committed Dec 20, 2024
1 parent 7b172e2 commit a3fc8b1
Showing 8 changed files with 164 additions and 229 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -46,7 +46,7 @@ DynamicPPLZygoteRulesExt = ["ZygoteRules"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
AbstractPPL = "0.10.1"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
86 changes: 39 additions & 47 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -48,64 +48,59 @@ end
Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`.
The `model` passed to `predict` is often different from the one used to generate `chain`.
Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
data points), while the model you pass to `predict` may mark these same variables as missing
or unobserved. Calling `predict` then leverages the previously inferred parameter values to
simulate what new, unobserved data might look like, given your posterior beliefs.
For each parameter configuration in `chain`:
1. All random variables present in `chain` are fixed to their sampled values.
2. Any variables not included in `chain` are sampled from their prior distributions.
If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
# Examples
```jldoctest
julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff;

julia> @model function linear_reg(x, y, σ = 0.1)
           β ~ Normal(0, 1)
           for i ∈ eachindex(y)
               y[i] ~ Normal(β * x[i], σ)
           end
       end;

julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();

julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);

julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);

julia> m_train = linear_reg(xs_train, ys_train, σ);

julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));

julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β], discard_initial=100)
┌ Info: Found initial step size
└ ϵ = 0.003125

julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);

julia> predictions = predict(m_test, chain_lin_reg)
Object of type Chains, with data of type 100×2×1 Array{Float64,3}

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = y[1], y[2]

2-element Array{ChainDataFrame,1}

Summary Statistics
  parameters     mean     std  naive_se     mcse       ess   r_hat
  ──────────  ───────  ──────  ────────  ───────  ────────  ──────
        y[1]  20.1974  0.1007    0.0101  missing  101.0711  0.9922
        y[2]  20.3867  0.1062    0.0106  missing  101.4889  0.9903

Quantiles
  parameters     2.5%    25.0%    50.0%    75.0%    97.5%
  ──────────  ───────  ───────  ───────  ───────  ───────
        y[1]  20.0342  20.1188  20.2135  20.2588  20.4188
        y[2]  20.1870  20.3178  20.3839  20.4466  20.5895

julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));

julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
true

using AbstractMCMC, Distributions, DynamicPPL, Random

@model function linear_reg(x, y, σ = 0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

# Generate synthetic chain using known ground truth parameter
ground_truth_β = 2.0

# Create chain of samples from a normal distribution centered on ground truth
β_chain = MCMCChains.Chains(
    rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
)

# Generate predictions for two test points
xs_test = [10.1, 10.2]

m_train = linear_reg(xs_test, fill(missing, length(xs_test)))

predictions = DynamicPPL.AbstractPPL.predict(
    Random.default_rng(), m_train, β_chain
)

ys_pred = vec(mean(Array(predictions); dims=1))

# Check if predictions match expected values within tolerance
(
    isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
    isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
)

# output

(true, true)
```
"""
function DynamicPPL.predict(
@@ -115,14 +110,11 @@ function DynamicPPL.predict(
include_all=false,
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
prototypical_varinfo = DynamicPPL.VarInfo(model)
varinfo = DynamicPPL.VarInfo(model)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
varinfo = deepcopy(prototypical_varinfo)
DynamicPPL.setval_and_resample!(
varinfo, parameter_only_chain, sample_idx, chain_idx
)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
model(rng, varinfo, DynamicPPL.SampleFromPrior())

vals = DynamicPPL.values_as_in_model(model, varinfo)
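The `include_all` keyword is documented above but not exercised in the new example. Below is a minimal sketch of the contrast, not part of this commit: it reuses the model and synthetic-chain setup from the new doctest and keeps the explicit-RNG call form shown there, since the diff does not show whether an RNG-free convenience method remains.

```julia
using DynamicPPL, Distributions, MCMCChains, Random

@model function linear_reg(x, y, σ = 0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

xs_test = [10.1, 10.2]
m_test = linear_reg(xs_test, fill(missing, length(xs_test)))

# Synthetic chain over β, as in the new doctest above.
β_chain = MCMCChains.Chains(rand(Normal(2.0, 0.002), 1000), [:β])

# Default (include_all=false): per the docstring, the result contains only variables
# not fixed by the chain, i.e. the newly simulated y[1] and y[2].
pred_new_only = DynamicPPL.AbstractPPL.predict(Random.default_rng(), m_test, β_chain)

# include_all=true: variables fixed from the chain (here β) are kept as well.
pred_all = DynamicPPL.AbstractPPL.predict(
    Random.default_rng(), m_test, β_chain; include_all=true
)
```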
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
@@ -40,6 +40,8 @@ import Base:
keys,
haskey

import AbstractPPL: predict

# VarInfo
export AbstractVarInfo,
VarInfo,
26 changes: 15 additions & 11 deletions src/model.jl
@@ -1145,19 +1145,23 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end

"""
predict([rng::AbstractRNG,] model::Model, chain; include_all=false)
predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})
Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`.
If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
and the predicted values.
"""
function predict(model::Model, chain; include_all=false)
# this is only defined in `ext/DynamicPPLMCMCChainsExt.jl`
# TODO: add other methods for different type of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict`
return predict(Random.default_rng(), model, chain; include_all)
function predict(
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
)
varinfo = DynamicPPL.VarInfo(model)
return map(chain) do params_varinfo
vi = deepcopy(varinfo)
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
model(rng, vi, SampleFromPrior())
return vi
end
end

"""
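A minimal usage sketch of the new vector-of-`VarInfo` `predict` method added above; `demo_model`, the data value, and the prior draws standing in for posterior samples are illustrative assumptions, not part of this commit.

```julia
using DynamicPPL, Distributions, Random

@model function demo_model(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

# For illustration, draw parameter configurations from the prior; in practice these
# VarInfos would come from an inference routine. Each VarInfo holds one draw of μ.
model_train = demo_model(0.5)
param_varinfos = [DynamicPPL.VarInfo(model_train) for _ in 1:10]

# Predict with `y` left missing: μ is fixed from each VarInfo and y is re-simulated,
# so every returned VarInfo carries both the parameter value and a predicted y.
model_predict = demo_model(missing)
predictive_varinfos = DynamicPPL.predict(Random.default_rng(), model_predict, param_varinfos)

predicted_ys = [vi[@varname(y)] for vi in predictive_varinfos]
```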
4 changes: 1 addition & 3 deletions test/Project.toml
@@ -2,7 +2,6 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -34,9 +33,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
AbstractPPL = "0.10.1"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
Bijectors = "0.15.1"
Combinatorics = "1"
Compat = "4.3.0"
167 changes: 1 addition & 166 deletions test/ext/DynamicPPLMCMCChainsExt.jl
@@ -8,169 +8,4 @@
@test mean(chain_generated) ≈ 0 atol = 0.1
end

@testset "predict" begin
DynamicPPL.Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)

for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

f(x) = 2 * x + 0.1 * randn()

Δ = 0.1
xs_train = 0:Δ:10
ys_train = f.(xs_train)
xs_test = [10 + Δ, 10 + 2 * Δ]
ys_test = f.(xs_test)

# Infer
m_lin_reg = linear_reg(xs_train, ys_train)
chain_lin_reg = sample(
DynamicPPL.LogDensityFunction(m_lin_reg),
AdvancedHMC.NUTS(0.65),
1000;
chain_type=MCMCChains.Chains,
param_names=[:β],
discard_initial=100,
n_adapt=100,
)

# Predict on two last indices
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

# test like this depends on the variance of the posterior
# this only makes sense if the posterior variance is about 0.002
@test sum(abs2, ys_test - ys_pred) ≤ 0.1

# Ensure that `rng` is respected
predictions1 = let rng = MersenneTwister(42)
DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
end
predictions2 = let rng = MersenneTwister(42)
DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
end
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1

# Multiple chains
chain_lin_reg = sample(
DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
AdvancedHMC.NUTS(0.65),
MCMCThreads(),
1000,
2;
chain_type=MCMCChains.Chains,
param_names=[:β],
discard_initial=100,
n_adapt=100,
)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

@test size(chain_lin_reg, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(chain_lin_reg)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@test sum(abs2, ys_test - ys_pred) ≤ 0.1
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

for chain_idx in MCMCChains.chains(chain_lin_reg)
ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
end

# https://github.com/TuringLang/Turing.jl/issues/1352
@model function simple_linear1(x, y)
intercept ~ Normal(0, 1)
coef ~ MvNormal(zeros(2), I)
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear2(x, y)
intercept ~ Normal(0, 1)
coef ~ filldist(Normal(0, 1), 2)
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear3(x, y)
intercept ~ Normal(0, 1)
coef = Vector(undef, 2)
for i in axes(coef, 1)
coef[i] ~ Normal(0, 1)
end
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear4(x, y)
intercept ~ Normal(0, 1)
coef1 ~ Normal(0, 1)
coef2 ~ Normal(0, 1)
coef = [coef1, coef2]
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

x = randn(2, 100)
y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

param_names = Dict(
simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear4 => [:intercept, :coef1, :coef2, :error],
)
@testset "$model" for model in
[simple_linear1, simple_linear2, simple_linear3, simple_linear4]
m = model(x, y)
chain = sample(
DynamicPPL.LogDensityFunction(m),
AdvancedHMC.NUTS(0.65),
400;
initial_params=rand(4),
chain_type=MCMCChains.Chains,
param_names=param_names[model],
discard_initial=100,
n_adapt=100,
)
chain_predict = DynamicPPL.predict(model(x, missing), chain)
mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
@test mean(abs2, mean_prediction - y) ≤ 1e-3
end
end
# test for `predict` is in `test/model.jl`