Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move predict from Turing #716

Merged
merged 19 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ DynamicPPLZygoteRulesExt = ["ZygoteRules"]
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
AbstractPPL = "0.10.1"
Accessors = "0.1"
BangBang = "0.4.1"
Bijectors = "0.13.18, 0.14, 0.15"
Expand Down
142 changes: 142 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,148 @@
return keys(c.info.varname_to_symbol)
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`.

The `model` passed to `predict` is often different from the one used to generate `chain`.
Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
data points), while the model you pass to `predict` may mark these same variables as missing
or unobserved. Calling `predict` then leverages the previously inferred parameter values to
simulate what new, unobserved data might look like, given your posterior beliefs.

For each parameter configuration in `chain`:
1. All random variables present in `chain` are fixed to their sampled values.
2. Any variables not included in `chain` are sampled from their prior distributions.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.

# Examples
```jldoctest
using AbstractMCMC, Distributions, DynamicPPL, Random

@model function linear_reg(x, y, σ = 0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

# Generate synthetic chain using known ground truth parameter
ground_truth_β = 2.0

# Create chain of samples from a normal distribution centered on ground truth
β_chain = MCMCChains.Chains(
rand(Normal(ground_truth_β, 0.002), 1000), [:β,]
)

# Generate predictions for two test points
xs_test = [10.1, 10.2]

m_train = linear_reg(xs_test, fill(missing, length(xs_test)))

predictions = DynamicPPL.AbstractPPL.predict(
Random.default_rng(), m_train, β_chain
)

ys_pred = vec(mean(Array(predictions); dims=1))

# Check if predictions match expected values within tolerance
(
isapprox(ys_pred[1], ground_truth_β * xs_test[1], atol = 0.01),
isapprox(ys_pred[2], ground_truth_β * xs_test[2], atol = 0.01)
)

# output

(true, true)
```
"""
function DynamicPPL.predict(
rng::DynamicPPL.Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains;
include_all=false,
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
varinfo = DynamicPPL.VarInfo(model)

iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
predictive_samples = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
model(rng, varinfo, DynamicPPL.SampleFromPrior())

vals = DynamicPPL.values_as_in_model(model, varinfo)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually changing the behavior from Turing.jl's implementation. This will result in also including variables used in := statements, which is not currently done.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooooh nice catch; thanks! Hmm, uncertain if this is desired behavior though 😕

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw your issue on :=, totally understand the concern here. But if we are not exporting predict, we can change this in near future, also we might want to use fix in the future, so the behavior will be right then.

We would need to make a minor release of Turing if we change this now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't this the purpose of this PR? To move the predict from Turing.jl to DynamicPPL.jl?

also we might want to use fix in the future

Whether we're using fix or not is just an internal impl detail, and is not relevant for its usage, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't this the purpose of this PR? To move the predict from Turing.jl to DynamicPPL.jl?

Ideally, I would want this PR to do a proper implementation of predict in DynamicPPL. But now, I am okay with the PR being only a first step towards that.

Whether we're using fix or not is just an internal impl detail, and is not relevant for its usage, right?

what I was trying to say is that, with fix it should have the right behavior (with regards to :=). Of course not the only way to reach the desired behavior.

Copy link
Member

@yebai yebai Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improving it in a separate PR sounds good, but please create an issue to track @torfjelde's comment.

varname_vals = mapreduce(
collect,
vcat,
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
end

chain_result = reduce(
MCMCChains.chainscat,
[
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
chain_idx in 1:size(predictive_samples, 2)
],
)
parameter_names = if include_all
MCMCChains.names(chain_result, :parameters)

Check warning on line 138 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L138

Added line #L138 was not covered by tests
else
filter(
k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
names(chain_result, :parameters),
)
end
return chain_result[parameter_names]
end

function _predictive_samples_to_arrays(predictive_samples)
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

sample_dicts = map(predictive_samples) do sample
varname_value_pairs = sample.varname_and_values
varnames = map(first, varname_value_pairs)
values = map(last, varname_value_pairs)
for varname in varnames
push!(variable_names_set, varname)
end

return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
end

variable_names = collect(variable_names_set)
variable_values = [
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
key in variable_names
]

return variable_names, variable_values
end

function _predictive_samples_to_chains(predictive_samples)
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
variable_names_symbols = map(Symbol, variable_names)

internal_parameters = [:lp]
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)

parameter_names = [variable_names_symbols; internal_parameters]
parameter_values = hcat(variable_values, log_probabilities)
parameter_values = MCMCChains.concretize(parameter_values)

return MCMCChains.Chains(
parameter_values, parameter_names, (internals=internal_parameters,)
)
end

"""
returned(model::Model, chain::MCMCChains.Chains)

Expand Down
4 changes: 3 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using AbstractPPL
using Bijectors
using Compat
using Distributions
using OrderedCollections: OrderedDict
using OrderedCollections: OrderedCollections, OrderedDict

using AbstractMCMC: AbstractMCMC
using ADTypes: ADTypes
Expand Down Expand Up @@ -40,6 +40,8 @@ import Base:
keys,
haskey

import AbstractPPL: predict

# VarInfo
export AbstractVarInfo,
VarInfo,
Expand Down
20 changes: 20 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,26 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end
end

"""
predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo})

Generate samples from the posterior predictive distribution by evaluating `model` at each set
of parameter values provided in `chain`. The number of posterior predictive samples matches
the length of `chain`. The returned `AbstractVarInfo`s will contain both the posterior parameter values
and the predicted values.
"""
function predict(
rng::Random.AbstractRNG, model::Model, chain::AbstractArray{<:AbstractVarInfo}
)
varinfo = DynamicPPL.VarInfo(model)
return map(chain) do params_varinfo
vi = deepcopy(varinfo)
DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple))
model(rng, vi, SampleFromPrior())
return vi
end
end

"""
returned(model::Model, parameters::NamedTuple)
returned(model::Model, values, keys)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
AbstractPPL = "0.10.1"
Accessors = "0.1"
Bijectors = "0.15.1"
Combinatorics = "1"
Expand Down
8 changes: 4 additions & 4 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,27 +202,27 @@ end
s, m = retval.s, retval.m

# Keword approach.
model_fixed = fix(model; s=s)
model_fixed = DynamicPPL.fix(model; s=s)
@test model_fixed().s == s
@test model_fixed().m != m
# A fixed variable should not contribute at all to the logjoint.
# Assuming `condition` is correctly implemented, the following should hold.
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))

# Positional approach.
model_fixed = fix(model, (; s))
model_fixed = DynamicPPL.fix(model, (; s))
@test model_fixed().s == s
@test model_fixed().m != m
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))

# Pairs approach.
model_fixed = fix(model, @varname(s) => s)
model_fixed = DynamicPPL.fix(model, @varname(s) => s)
@test model_fixed().s == s
@test model_fixed().m != m
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))

# Dictionary approach.
model_fixed = fix(model, Dict(@varname(s) => s))
model_fixed = DynamicPPL.fix(model, Dict(@varname(s) => s))
@test model_fixed().s == s
@test model_fixed().m != m
@test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m))
Expand Down
2 changes: 2 additions & 0 deletions test/ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
@test size(chain_generated) == (1000, 1)
@test mean(chain_generated) ≈ 0 atol = 0.1
end

# test for `predict` is in `test/model.jl`
105 changes: 105 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,4 +429,109 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
@test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result)
end
end

@testset "predict" begin
@testset "with MCMCChains.Chains" begin
DynamicPPL.Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

ground_truth_β = 2
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β])

xs_test = [10 + 0.1, 10 + 2 * 0.1]
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
@test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01

# Ensure that `rng` is respected
rng = MersenneTwister(42)
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
predictions2 = DynamicPPL.predict(
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
)
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01

# Multiple chains
multiple_β_chain = MCMCChains.Chains(
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β]
)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
@test size(multiple_β_chain, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred_vec = vec(
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
)
@test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end
end

@testset "with AbstractVector{<:AbstractVarInfo}" begin
@model function linear_reg(x, y, σ=0.1)
β ~ Normal(1, 1)
for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

ground_truth_β = 2.0
# the data will be ignored, as we are generating samples from the prior
xs_train = 1:0.1:10
ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train))
m_lin_reg = linear_reg(xs_train, ys_train)
chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000]

# chain is generated from the prior
@test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1

xs_test = [10 + 0.1, 10 + 2 * 0.1]
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predicted_vis = DynamicPPL.predict(m_lin_reg_test, chain)

@test size(predicted_vis) == size(chain)
@test Set(keys(predicted_vis[1])) ==
Set([@varname(β), @varname(y[1]), @varname(y[2])])
# because β samples are from the prior, the std will be larger
@test mean([
predicted_vis[i][@varname(y[1])] for i in eachindex(predicted_vis)
]) ≈ 1.0 * xs_test[1] rtol = 0.1
@test mean([
predicted_vis[i][@varname(y[2])] for i in eachindex(predicted_vis)
]) ≈ 1.0 * xs_test[2] rtol = 0.1
end
end
end
Loading