Move predict from Turing #716
@@ -1203,6 +1203,42 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
    end
end

struct PredictiveSample{T,F}
    values::T
    logp::F
end

"""
    predict([rng::AbstractRNG,] model::Model, chain; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
"""
function predict(model::Model, chain; include_all=false)
[Review thread]

- In Turing.jl we're currently overloading […]
- Agree with this, but probably not time yet. Definitely after TuringLang/AbstractPPL.jl#81 is merged 👍
- But is this PR then held up until that PR is merged?
- Also, that PR doesn't really matter; overloading […]
- Grey area: for me it is okay, because this PR is just about introducing a […]
- If nothing significant is missing in TuringLang/AbstractPPL.jl#81, let's merge it and overload […]
    return predict(Random.default_rng(), model, chain; include_all)
end

function predict(
    rng::Random.AbstractRNG,
    model::Model,
    varinfos::AbstractArray{<:AbstractVarInfo};
    include_all=false,
)
    predictive_samples = Array{PredictiveSample}(undef, size(varinfos))
[Review thread]

- Do we really need the […]? My original suggestion was just to use […]
- I haven't thought too deeply about this. A new type is certainly easier to dispatch on, but may not be necessary. Let me look into it.
- But we don't need to dispatch on this, do we? Also, maybe it makes more sense to follow the convention of returning the same type as the input type, i.e. in this case we should return a […]
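Not part of the diff: a minimal sketch of one possible reading of the truncated suggestion above, i.e. returning results in the same container type as the input (the updated `VarInfo`s) instead of introducing a `PredictiveSample` wrapper. The helper name `predict_varinfos` is hypothetical.

```julia
using DynamicPPL, Random

# Hypothetical alternative (not the PR's implementation): return the mutated
# VarInfos themselves, so the output container type matches the input type.
function predict_varinfos(
    rng::Random.AbstractRNG, model::Model, varinfos::AbstractArray{<:AbstractVarInfo}
)
    return map(varinfos) do vi
        # Re-execute the model with this VarInfo, as in the PR's `predict` loop.
        model(rng, vi, SampleFromPrior())
        vi
    end
end
```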
    for i in eachindex(varinfos)
        model(rng, varinfos[i], SampleFromPrior())
        vals = values_as_in_model(model, varinfos[i])
        iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
        params = mapreduce(collect, vcat, iters)
        predictive_samples[i] = PredictiveSample(params, getlogp(varinfos[i]))
    end
    return predictive_samples
end

"""
    generated_quantities(model::Model, parameters::NamedTuple)
    generated_quantities(model::Model, values, keys)
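Not part of the diff: a hedged usage sketch of the `Chains`-based call path described by the docstring above. The model, data, and hand-built chain are illustrative stand-ins, and `using MCMCChains` is assumed because the `Chains` method is provided via the MCMCChains extension.

```julia
using DynamicPPL, Distributions, MCMCChains, Statistics

@model function demo(x, y)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], 0.1)
    end
end

# Any `MCMCChains.Chains` holding draws of β will do; here we fabricate one
# (iterations × parameters × chains).
chain = MCMCChains.Chains(reshape(randn(100), :, 1, 1), [:β])

# Fix β to each draw in `chain` and sample the missing y's from the predictive distribution.
m_test = demo([1.0, 2.0], fill(missing, 2))
predictions = DynamicPPL.predict(m_test, chain)
ys_pred = vec(mean(Array(MCMCChains.group(predictions, :y)); dims=1))  # predictive means of y[1], y[2]
```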
@@ -2,6 +2,7 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
[Review thread]

- Hmm, this doesn't quite seem worth it to test […]
- I didn't add anything or change the implementation in this PR. Agree AHMC is a heavy dep, but tests like https://github.com/TuringLang/DynamicPPL.jl/blob/fd1277b7201477448d3257cab65557b850bcf5b4/test/ext/DynamicPPLMCMCChainsExt.jl#L48C1-L55C45 […]
- Sure, but we should just replace them with samples from the prior or something. This is just checking that the statistics are correct; it doesn't matter whether these statistics come from the prior or the posterior 🤷
- Actually, would it be really bad to make […]
- I can't look at this PR properly until Wednesday, but in https://github.com/TuringLang/DynamicPPL.jl/pull/733/files#diff-3981168ff1709b3f48c35e40f491c26d9b91fc29373e512f1272f3b928cea6c0 I wrote a function that generates a chain by sampling from the prior. (It's called […])
- @sunxd3, @penelopeysm the posterior of Bayesian linear regression can be obtained in closed form (i.e. it is a Gaussian, see here). I suggest that […]
[Review thread]

- Though the closed-form posterior is a good idea, there's really no need to run this test on posterior samples :) These were just some stats that were picked to have something to compare to; a prior chain is the way to go, I think 👍
- A prior chain makes sense: should we generate samples from the prior, take out the samples of a particular variable, and try to predict it?
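Not part of the diff: a sketch of the closed-form posterior mentioned in the thread, for the scalar test model `β ~ Normal(0, 1)`, `y[i] ~ Normal(β * x[i], σ)` with known `σ`. The helper name is illustrative; standard conjugate-Gaussian algebra gives a Gaussian posterior that could replace the NUTS run when exact posterior samples are wanted.

```julia
using Distributions

# Conjugate update for β ~ Normal(0, 1) with y[i] ~ Normal(β * x[i], σ):
#   posterior variance  v = 1 / (1 + Σ x[i]^2 / σ^2)
#   posterior mean      m = v * Σ x[i] * y[i] / σ^2
function closed_form_posterior(x, y, σ)
    v = 1 / (1 + sum(abs2, x) / σ^2)
    m = v * sum(x .* y) / σ^2
    return Normal(m, sqrt(v))
end

# Exact posterior draws without any MCMC, e.g.:
# βs = rand(closed_form_posterior(xs_train, ys_train, 0.1), 1000)
```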
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"

@@ -32,6 +33,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
Accessors = "0.1"
Bijectors = "0.13.9, 0.14"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
Combinatorics = "1"
Compat = "4.3.0"
Distributions = "0.25"
@@ -7,3 +7,170 @@
    @test size(chain_generated) == (1000, 1)
    @test mean(chain_generated) ≈ 0 atol = 0.1
end

@testset "predict" begin
    DynamicPPL.Random.seed!(100)

    @model function linear_reg(x, y, σ=0.1)
        β ~ Normal(0, 1)

        for i in eachindex(y)
            y[i] ~ Normal(β * x[i], σ)
        end
    end

    @model function linear_reg_vec(x, y, σ=0.1)
        β ~ Normal(0, 1)
        return y ~ MvNormal(β .* x, σ^2 * I)
    end

    f(x) = 2 * x + 0.1 * randn()

    Δ = 0.1
    xs_train = 0:Δ:10
    ys_train = f.(xs_train)
    xs_test = [10 + Δ, 10 + 2 * Δ]
    ys_test = f.(xs_test)

    # Infer
    m_lin_reg = linear_reg(xs_train, ys_train)
    chain_lin_reg = sample(
        DynamicPPL.LogDensityFunction(m_lin_reg),
        AdvancedHMC.NUTS(0.65),
[Review thread]

- Really doesn't seem necessary to use […]
- Same reason as above: some tests rely on the quality of the samples.
        1000;
        chain_type=MCMCChains.Chains,
        param_names=[:β],
        discard_initial=100,
        n_adapt=100,
    )

    # Predict on two last indices
    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

    # test like this depends on the variance of the posterior
    # this only makes sense if the posterior variance is about 0.002
    @test sum(abs2, ys_test - ys_pred) ≤ 0.1

    # Ensure that `rng` is respected
    predictions1 = let rng = MersenneTwister(42)
        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
    end
    predictions2 = let rng = MersenneTwister(42)
        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
    end
    @test all(Array(predictions1) .== Array(predictions2))

    # Predict on two last indices for vectorized
    m_lin_reg_test = linear_reg_vec(xs_test, missing)
    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
    ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

    @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1

    # Multiple chains
    chain_lin_reg = sample(
        DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
        AdvancedHMC.NUTS(0.65),
        MCMCThreads(),
        1000,
        2;
        chain_type=MCMCChains.Chains,
        param_names=[:β],
        discard_initial=100,
        n_adapt=100,
    )
    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    @test size(chain_lin_reg, 3) == size(predictions, 3)

    for chain_idx in MCMCChains.chains(chain_lin_reg)
        ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
        @test sum(abs2, ys_test - ys_pred) ≤ 0.1
    end

    # Predict on two last indices for vectorized
    m_lin_reg_test = linear_reg_vec(xs_test, missing)
    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    for chain_idx in MCMCChains.chains(chain_lin_reg)
        ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
        @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
    end

    # https://github.com/TuringLang/Turing.jl/issues/1352
    @model function simple_linear1(x, y)
        intercept ~ Normal(0, 1)
        coef ~ MvNormal(zeros(2), I)
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear2(x, y)
        intercept ~ Normal(0, 1)
        coef ~ filldist(Normal(0, 1), 2)
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear3(x, y)
        intercept ~ Normal(0, 1)
        coef = Vector(undef, 2)
        for i in axes(coef, 1)
            coef[i] ~ Normal(0, 1)
        end
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear4(x, y)
        intercept ~ Normal(0, 1)
        coef1 ~ Normal(0, 1)
        coef2 ~ Normal(0, 1)
        coef = [coef1, coef2]
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    x = randn(2, 100)
    y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

    param_names = Dict(
        simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear4 => [:intercept, :coef1, :coef2, :error],
    )
    @testset "$model" for model in
                          [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
        m = model(x, y)
        chain = sample(
            DynamicPPL.LogDensityFunction(m),
            AdvancedHMC.NUTS(0.65),
            400;
            initial_params=rand(4),
            chain_type=MCMCChains.Chains,
            param_names=param_names[model],
            discard_initial=100,
            n_adapt=100,
        )
        chain_predict = DynamicPPL.predict(model(x, missing), chain)
        mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
        @test mean(abs2, mean_prediction - y) ≤ 1e-3
    end
end
[Review comment]

Same here: no need to use AdvancedHMC (or any of the other packages), just construct the `Chains` by hand. This test also doesn't actually show that you need to import MCMCChains for this to work, which might be a good idea to demonstrate.
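Not part of the PR: a hedged sketch of the hand-constructed test chain the reviewers suggest, using prior draws of `β` instead of an AdvancedHMC run. Variable names mirror the test above, the explicit `using MCMCChains` makes the extension requirement visible, and the final check compares against statistics of the same chain rather than against held-out "truth".

```julia
using DynamicPPL, Distributions, MCMCChains, Random, Statistics, Test

@model function linear_reg(x, y, σ=0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

rng = MersenneTwister(100)
xs_test = [10.1, 10.2]

# Construct the Chains by hand from prior draws of β (iterations × parameters × chains);
# no sampler package is needed for this.
βs = rand(rng, Normal(0, 1), 1000)
chain = MCMCChains.Chains(reshape(βs, :, 1, 1), [:β])

m_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_test, chain)
ys_pred = vec(mean(Array(MCMCChains.group(predictions, :y)); dims=1))

# Compare against the same chain's statistics, not against held-out "truth".
@test ys_pred ≈ mean(βs) .* xs_test atol = 0.1
```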