Move predict from Turing #716

Merged: 19 commits, Dec 20, 2024 (showing changes from 10 commits).

133 changes: 133 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -42,6 +42,139 @@
return keys(c.info.varname_to_symbol)
end

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.

# Examples
```jldoctest
julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, ForwardDiff;

Member: Same here: no need to use AdvancedHMC (or any of the other packages); just construct the Chains by hand. This also doesn't actually show that you need to import MCMCChains for this to work, which might be a good idea to add.

julia> @model function linear_reg(x, y, σ = 0.1)
β ~ Normal(0, 1)
for i ∈ eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end;

julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();

julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);

julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);

julia> m_train = linear_reg(xs_train, ys_train, σ);

julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));

julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β], discard_initial=100)
┌ Info: Found initial step size
└ ϵ = 0.003125

julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);

julia> predictions = predict(m_test, chain_lin_reg)
Object of type Chains, with data of type 100×2×1 Array{Float64,3}

Iterations = 1:100
Thinning interval = 1
Chains = 1
Samples per chain = 100
parameters = y[1], y[2]

2-element Array{ChainDataFrame,1}

Summary Statistics
parameters mean std naive_se mcse ess r_hat
────────── ─────── ────── ──────── ─────── ──────── ──────
y[1] 20.1974 0.1007 0.0101 missing 101.0711 0.9922
y[2] 20.3867 0.1062 0.0106 missing 101.4889 0.9903

Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
────────── ─────── ─────── ─────── ─────── ───────
y[1] 20.0342 20.1188 20.2135 20.2588 20.4188
y[2] 20.1870 20.3178 20.3839 20.4466 20.5895

julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));

julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
true
```
"""
function DynamicPPL.predict(
rng::DynamicPPL.Random.AbstractRNG,
model::DynamicPPL.Model,
chain::MCMCChains.Chains;
include_all=false,
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
vi = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
varinfos = map(iters) do (sample_idx, chain_idx)
DynamicPPL.setval_and_resample!(
deepcopy(vi), parameter_only_chain, sample_idx, chain_idx
)
end

predictive_samples = DynamicPPL.predict(rng, model, varinfos; include_all)

chain_result = reduce(
MCMCChains.chainscat,
[
_bundle_samples(predictive_samples[:, chain_idx]) for
chain_idx in 1:size(predictive_samples, 2)
],
)
parameter_names = if include_all
MCMCChains.names(chain_result, :parameters)
else
filter(
k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
names(chain_result, :parameters),
)
end
return chain_result[parameter_names]
end

function _params_to_array(ts::Vector)
names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

dicts = map(ts) do t
nms_and_vs = t.values
nms = map(first, nms_and_vs)
vs = map(last, nms_and_vs)
for nm in nms
push!(names_set, nm)
end
return DynamicPPL.OrderedCollections.OrderedDict(zip(nms, vs))
end

names = collect(names_set)
vals = [
get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names)
]

return names, vals
end
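
To make the padding behavior concrete, here is a hypothetical illustration; it is not part of the diff, and the `PredictiveSample` values are made up. Samples that expose different variable sets get `missing` entries in the resulting matrix.

```julia
# Hypothetical illustration only, not part of this PR: two predictive samples with
# different variable sets; _params_to_array pads absent entries with `missing`.
s1 = DynamicPPL.PredictiveSample([(@varname(y[1]), 1.0), (@varname(y[2]), 2.0)], -1.2)
s2 = DynamicPPL.PredictiveSample([(@varname(y[1]), 0.5)], -0.7)

names, vals = _params_to_array([s1, s2])
# names == [y[1], y[2]]            (VarNames, in order of first appearance)
# vals  == [1.0 2.0; 0.5 missing]  (one row per sample, one column per variable)
```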

function _bundle_samples(ts::Vector{<:DynamicPPL.PredictiveSample})
varnames, vals = _params_to_array(ts)
varnames_symbol = map(Symbol, varnames)
extra_params = [:lp]
extra_values = reshape([t.logp for t in ts], :, 1)
nms = [varnames_symbol; extra_params]
parray = hcat(vals, extra_values)
parray = MCMCChains.concretize(parray)
return MCMCChains.Chains(parray, nms, (internals=extra_params,))
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)

2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
@@ -5,7 +5,7 @@ using AbstractPPL
using Bijectors
using Compat
using Distributions
-using OrderedCollections: OrderedDict
+using OrderedCollections: OrderedCollections, OrderedDict

using AbstractMCMC: AbstractMCMC
using ADTypes: ADTypes
36 changes: 36 additions & 0 deletions src/model.jl
@@ -1203,6 +1203,42 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end
end

struct PredictiveSample{T,F}
values::T
logp::F
end

"""
predict([rng::AbstractRNG,] model::Model, chain; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
"""
function predict(model::Model, chain; include_all=false)

Member: In Turing.jl we're currently overloading StatsBase.predict, so we should probably do the same here, no?

Member (Author): Agree with this, but probably not the time yet. Definitely after TuringLang/AbstractPPL.jl#81 is merged 👍

Member: But is this PR then held up until that PR is merged?

Member: Also, that PR doesn't really matter; overloading StatsBase.predict here and now just means that we'll immediately be compliant with the AbstractPPL.jl interface when that PR merges?

Member (Author): Grey area: for me it is okay, because this PR is just about introducing a Turing-facing predict, not a user-facing one yet. At the moment predict is not a public API.

Member: If nothing significant is missing in TuringLang/AbstractPPL.jl#81, let's merge it and overload AbstractPPL.predict here.

return predict(Random.default_rng(), model, chain; include_all)
end
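
To make the suggestion in the thread above concrete, here is a minimal sketch of what overloading `StatsBase.predict` could look like. This is hypothetical, not part of the diff; it assumes the code lives inside the DynamicPPL module (so `Model` and `Random` are in scope) and that StatsBase is added as a dependency.

```julia
# Hypothetical sketch, not part of this PR: define the method on StatsBase.predict
# instead of a DynamicPPL-owned `predict`, as suggested in the review thread above.
# Assumes both the rng-less and the rng-accepting methods are moved onto StatsBase.predict.
using StatsBase: StatsBase

function StatsBase.predict(model::Model, chain; include_all=false)
    return StatsBase.predict(Random.default_rng(), model, chain; include_all)
end
```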

function predict(
rng::Random.AbstractRNG,
model::Model,
varinfos::AbstractArray{<:AbstractVarInfo};
include_all=false,
)
predictive_samples = Array{PredictiveSample}(undef, size(varinfos))

Member: Do we really need the PredictiveSample here? My original suggestion was just to use Vector{<:OrderedDict} for the return value (an abstractly typed PredictiveSample doesn't really offer anything beyond this, does it?)

Member (Author): I haven't thought too deeply about this. A new type is certainly easier to dispatch on, but it may not be necessary. Let me look into it.

Member: But we don't need to dispatch on this, do we? Also, maybe it makes more sense to follow the convention of returning the same type as the input, i.e. in this case we should return an AbstractArray{<:AbstractVarInfo}, and in the Chains case we return a Chains.

for i in eachindex(varinfos)
model(rng, varinfos[i], SampleFromPrior())
vals = values_as_in_model(model, varinfos[i])
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
params = mapreduce(collect, vcat, iters)
predictive_samples[i] = PredictiveSample(params, getlogp(varinfos[i]))
end
return predictive_samples
end
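
For comparison with the `PredictiveSample` approach, here is a rough sketch of the `Vector{<:OrderedDict}` alternative raised in the thread above. The name `predict_to_dicts` is made up for illustration, it is not part of this PR, and the log density would have to be carried separately.

```julia
# Hypothetical sketch of the OrderedDict-based return value discussed above; not part
# of this PR. Each sample becomes an OrderedDict of VarName => value.
function predict_to_dicts(
    rng::Random.AbstractRNG, model::Model, varinfos::AbstractArray{<:AbstractVarInfo}
)
    return map(varinfos) do vi
        model(rng, vi, SampleFromPrior())
        vals = values_as_in_model(model, vi)
        leaves = mapreduce(
            collect, vcat, map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
        )
        OrderedDict(first(l) => last(l) for l in leaves)
    end
end
```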

"""
generated_quantities(model::Model, parameters::NamedTuple)
generated_quantities(model::Model, values, keys)
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -2,6 +2,7 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"

Member: Hmm, this doesn't quite seem worth it just to test predict, no? What's the reasoning here?

Member (Author): I didn't add anything or change the implementation in this PR. Agreed that AdvancedHMC is a heavy dependency, but tests like https://github.com/TuringLang/DynamicPPL.jl/blob/fd1277b7201477448d3257cab65557b850bcf5b4/test/ext/DynamicPPLMCMCChainsExt.jl#L48C1-L55C45 rely on the quality of the samples.

Member: Sure, but we should just replace them with samples from the prior or something. This is just checking that the statistics are correct; it doesn't matter whether those statistics come from the prior or the posterior 🤷

Member (Author): Actually, would it be really bad to make AdvancedHMC a test dependency of DynamicPPL? (Again, I don't like this either, but it's not too bad. I would rather add an issue for removing this dependency later than tamper with this PR any more.)

Member: I can't look at this PR properly until Wednesday, but in https://github.com/TuringLang/DynamicPPL.jl/pull/733/files#diff-3981168ff1709b3f48c35e40f491c26d9b91fc29373e512f1272f3b928cea6c0 I wrote a function that generates a chain by sampling from the prior (it's called make_chain_from_prior if the link doesn't take you to the right place). Feel free to take it if you think it's useful :)
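
As a rough illustration of the prior-chain idea, something along these lines could replace the NUTS runs for the purely statistical checks. This is not the actual `make_chain_from_prior` from #733; the helper name is made up, and it assumes every model variable is scalar.

```julia
# Hypothetical sketch, not the make_chain_from_prior helper from #733: build an
# MCMCChains.Chains by hand from prior draws, avoiding the AdvancedHMC test dependency.
# Assumes every model variable is scalar.
using DynamicPPL, MCMCChains, Random

function chain_from_prior(rng::Random.AbstractRNG, model::DynamicPPL.Model, n_samples::Int)
    varinfos = [DynamicPPL.VarInfo(rng, model) for _ in 1:n_samples]
    varnames = collect(keys(first(varinfos)))
    # iterations × parameters matrix of prior draws (one chain)
    vals = [vi[vn] for vi in varinfos, vn in varnames]
    return MCMCChains.Chains(vals, map(Symbol, varnames))
end
```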

Member: @sunxd3, @penelopeysm the posterior of Bayesian linear regression can be obtained in closed form (i.e. it is a Gaussian, see here). I suggest that we:

  1. add this BLR model to the DynamicPPL test models,
  2. implement its analytical posterior,
  3. sample from the analytical posterior directly and drop the AdvancedHMC dependency.
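
For reference, a short sketch of that closed-form posterior for the test model used here (β ~ Normal(0, 1), y[i] ~ Normal(β * x[i], σ)); the function name is made up and this is not part of the PR.

```julia
# Hypothetical sketch of the analytical posterior suggested above, for the conjugate model
#   β ~ Normal(0, 1);  y[i] ~ Normal(β * x[i], σ)
# The posterior over β is Gaussian with precision 1 + sum(x .^ 2) / σ^2.
using Distributions

function analytic_blr_posterior(x, y, σ)
    post_var = 1 / (1 + sum(abs2, x) / σ^2)   # posterior variance
    post_mean = post_var * sum(x .* y) / σ^2  # posterior mean
    return Normal(post_mean, sqrt(post_var))
end
```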

Member: Though the closed-form posterior is a good idea, there's really no need to run this test on posterior samples :) These were just some statistics picked so that there was something to compare against; a prior chain is the way to go, I think 👍

Member (Author): A prior chain makes sense: should we generate samples from the prior, take out the samples of a particular variable, and try to predict it?

Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
@@ -32,6 +33,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
Accessors = "0.1"
Bijectors = "0.13.9, 0.14"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
Combinatorics = "1"
Compat = "4.3.0"
Distributions = "0.25"
167 changes: 167 additions & 0 deletions test/ext/DynamicPPLMCMCChainsExt.jl
@@ -7,3 +7,170 @@
@test size(chain_generated) == (1000, 1)
@test mean(chain_generated) ≈ 0 atol = 0.1
end

@testset "predict" begin
DynamicPPL.Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)

for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

f(x) = 2 * x + 0.1 * randn()

Δ = 0.1
xs_train = 0:Δ:10
ys_train = f.(xs_train)
xs_test = [10 + Δ, 10 + 2 * Δ]
ys_test = f.(xs_test)

# Infer
m_lin_reg = linear_reg(xs_train, ys_train)
chain_lin_reg = sample(
DynamicPPL.LogDensityFunction(m_lin_reg),
AdvancedHMC.NUTS(0.65),

Member: It really doesn't seem necessary to use NUTS here. Just construct a Chains by hand or something, no?

Member (Author): Same reason as above: some tests rely on the quality of the samples.

1000;
chain_type=MCMCChains.Chains,
param_names=[:β],
discard_initial=100,
n_adapt=100,
)

# Predict on two last indices
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

# A test like this depends on the variance of the posterior;
# it only makes sense if the posterior variance is about 0.002.
@test sum(abs2, ys_test - ys_pred) ≤ 0.1

# Ensure that `rng` is respected
predictions1 = let rng = MersenneTwister(42)
DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
end
predictions2 = let rng = MersenneTwister(42)
DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
end
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1

# Multiple chains
chain_lin_reg = sample(
DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
AdvancedHMC.NUTS(0.65),
MCMCThreads(),
1000,
2;
chain_type=MCMCChains.Chains,
param_names=[:β],
discard_initial=100,
n_adapt=100,
)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

@test size(chain_lin_reg, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(chain_lin_reg)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@test sum(abs2, ys_test - ys_pred) ≤ 0.1
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

for chain_idx in MCMCChains.chains(chain_lin_reg)
ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
end

# https://github.com/TuringLang/Turing.jl/issues/1352
@model function simple_linear1(x, y)
intercept ~ Normal(0, 1)
coef ~ MvNormal(zeros(2), I)
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear2(x, y)
intercept ~ Normal(0, 1)
coef ~ filldist(Normal(0, 1), 2)
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear3(x, y)
intercept ~ Normal(0, 1)
coef = Vector(undef, 2)
for i in axes(coef, 1)
coef[i] ~ Normal(0, 1)
end
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear4(x, y)
intercept ~ Normal(0, 1)
coef1 ~ Normal(0, 1)
coef2 ~ Normal(0, 1)
coef = [coef1, coef2]
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

x = randn(2, 100)
y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

param_names = Dict(
simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
simple_linear4 => [:intercept, :coef1, :coef2, :error],
)
@testset "$model" for model in
[simple_linear1, simple_linear2, simple_linear3, simple_linear4]
m = model(x, y)
chain = sample(
DynamicPPL.LogDensityFunction(m),
AdvancedHMC.NUTS(0.65),
400;
initial_params=rand(4),
chain_type=MCMCChains.Chains,
param_names=param_names[model],
discard_initial=100,
n_adapt=100,
)
chain_predict = DynamicPPL.predict(model(x, missing), chain)
mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
@test mean(abs2, mean_prediction - y) ≤ 1e-3
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -1,5 +1,6 @@
using Accessors
using ADTypes
using AdvancedHMC: AdvancedHMC
using DynamicPPL
using AbstractMCMC
using AbstractPPL