Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't include lhs of := in results of predict() #766

Merged
merged 5 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.32.2"
version = "0.33.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function DynamicPPL.predict(
DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx)
model(rng, varinfo, DynamicPPL.SampleFromPrior())

vals = DynamicPPL.values_as_in_model(model, varinfo)
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
varname_vals = mapreduce(
collect,
vcat,
Expand Down
25 changes: 14 additions & 11 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,22 @@ $(TYPEDFIELDS)
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
# Accumulator handed back to the caller by `values_as_in_model` once the model
# has been evaluated; the two-argument constructor starts it as an empty
# `OrderedDict`.
"values that are extracted from the model"
values::T
# This flag is exactly what `is_extracting_values` returns for this context type.
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
# Wrapped context: read through `childcontext`, replaced through `setchildcontext`.
"child context"
context::C
end

# Convenience constructor: wrap an existing `values` container around a
# `DefaultContext`.
#
# This must call the three-field constructor directly. The two-argument method
# is `(include_colon_eq, context)`, so forwarding `(values, DefaultContext())`
# to it would silently discard `values` and try to use it as the Bool flag.
# `include_colon_eq` defaults to `true` here, which matches the behaviour this
# constructor had when extraction of `:=` variables was unconditional.
ValuesAsInModelContext(values) = ValuesAsInModelContext(values, true, DefaultContext())
# Create a context that starts with an empty `OrderedDict` of extracted values.
#
# - `include_colon_eq`: whether variables on the LHS of `:=` should be extracted.
# - `context`: the child context to wrap.
function ValuesAsInModelContext(include_colon_eq, context::AbstractContext)
    return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context)
end

# `ValuesAsInModelContext` wraps another context, so it is a parent node.
function NodeTrait(::ValuesAsInModelContext)
    return IsParent()
end

# The wrapped context is stored in the `context` field.
function childcontext(context::ValuesAsInModelContext)
    return context.context
end
# Return a copy of `context` that wraps `child` instead of its current child.
function setchildcontext(context::ValuesAsInModelContext, child)
    # Carry over both the accumulated values and the `include_colon_eq` flag;
    # only the child context changes.
    return ValuesAsInModelContext(context.values, context.include_colon_eq, child)
end

# Whether values on the LHS of `:=` should be extracted is decided by the
# context's own `include_colon_eq` flag. There must be exactly one definition
# of this method: a second one with the same signature would overwrite it.
is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq
# Generic entry point: look up the context's node trait and dispatch on it
# through the two-argument `is_extracting_values` method.
function is_extracting_values(context::AbstractContext)
    trait = NodeTrait(context)
    return is_extracting_values(trait, context)
end
Expand Down Expand Up @@ -114,8 +114,8 @@ function dot_tilde_assume(
end

"""
values_as_in_model(model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])

Get the values of `varinfo` as they would be seen in the model.

Expand All @@ -132,6 +132,7 @@ of additional model evaluations.

# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
Expand Down Expand Up @@ -183,24 +184,26 @@ false
julia> # Approach 2: Extract realizations using `values_as_in_model`.
# (✓) `values_as_in_model` will re-run the model and extract
# the correct realization of `y` given the new values of `x`.
lb ≤ values_as_in_model(model, varinfo_linked)[@varname(y)] ≤ ub
lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub
true
```
"""
function values_as_in_model(
    model::Model,
    include_colon_eq::Bool,
    varinfo::AbstractVarInfo=VarInfo(),
    context::AbstractContext=DefaultContext(),
)
    # Wrap the caller's context so that values are recorded during evaluation.
    # A new local is used rather than rebinding `context` to a different type.
    extraction_context = ValuesAsInModelContext(include_colon_eq, context)
    # Only the side effect on `extraction_context.values` is needed here; the
    # return value of `evaluate!!` is deliberately ignored.
    evaluate!!(model, varinfo, extraction_context)
    return extraction_context.values
end
# RNG-taking variant: identical to the method without `rng`, except that
# `context` is wrapped in a `SamplingContext` carrying `rng`.
function values_as_in_model(
    rng::Random.AbstractRNG,
    model::Model,
    include_colon_eq::Bool,
    varinfo::AbstractVarInfo=VarInfo(),
    context::AbstractContext=DefaultContext(),
)
    # Forward the caller's `include_colon_eq`. Hard-coding `true` here would
    # make this method always include variables on the LHS of `:=`, whatever
    # the caller asked for.
    sampling_context = SamplingContext(rng, context)
    return values_as_in_model(model, include_colon_eq, varinfo, sampling_context)
end
11 changes: 9 additions & 2 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,10 +702,17 @@ module Issue537 end
@test haskey(varinfo, @varname(x))
@test !haskey(varinfo, @varname(y))

# While `values_as_in_model` should contain both `x` and `y`.
values = values_as_in_model(model, deepcopy(varinfo))
# While `values_as_in_model` should contain both `x` and `y`, if
# include_colon_eq is set to `true`.
values = values_as_in_model(model, true, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test haskey(values, @varname(y))

# And if include_colon_eq is set to `false`, then `values` should
# only contain `x`.
values = values_as_in_model(model, false, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test !haskey(values, @varname(y))
end
end

Expand Down
120 changes: 77 additions & 43 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
realizations = values_as_in_model(model, varinfo)
# We can set the include_colon_eq arg to false because none of
# the demo models contain :=. The behaviour when
# include_colon_eq is true is tested in test/compiler.jl
realizations = values_as_in_model(model, false, varinfo)
# Ensure that all variables are found.
vns_found = collect(keys(realizations))
@test vns ∩ vns_found == vns ∪ vns_found
Expand All @@ -393,6 +396,22 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end
end

@testset "check that sampling obeys rng if passed" begin
    @model function f()
        x ~ Normal(0)
        return y ~ Normal(x)
    end
    model = f()
    # Call values_as_in_model with the rng
    values = values_as_in_model(Random.Xoshiro(43), model, false)
    # Make sure both variables were actually extracted; otherwise the
    # comparison loop below would pass vacuously on an empty result.
    @test haskey(values, @varname(x))
    @test haskey(values, @varname(y))
    # Check that the extracted values match those obtained when a VarInfo is
    # constructed with an identically seeded rng
    expected_vi = VarInfo(Random.Xoshiro(43), model)
    for vn in keys(values)
        @test values[vn] == expected_vi[vn]
    end
end
end

@testset "Erroneous model call" begin
Expand Down Expand Up @@ -432,72 +451,87 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()

@testset "predict" begin
@testset "with MCMCChains.Chains" begin
DynamicPPL.Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)
for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
# Insert a := block to test that it is not included in predictions
return σ2 := σ^2
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

# Construct a chain with 'sampled values' of β
ground_truth_β = 2
β_chain = MCMCChains.Chains(rand(Normal(ground_truth_β, 0.002), 1000), [:β])

# Generate predictions from that chain
xs_test = [10 + 0.1, 10 + 2 * 0.1]
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, β_chain)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
@test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01

# Ensure that `rng` is respected
rng = MersenneTwister(42)
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
predictions2 = DynamicPPL.predict(
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
)
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, β_chain)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
# Also test a vectorized model
@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end
m_lin_reg_test_vec = linear_reg_vec(xs_test, missing)

# Multiple chains
multiple_β_chain = MCMCChains.Chains(
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β]
)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(xs_test)))
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
@test size(multiple_β_chain, 3) == size(predictions, 3)
@testset "variables in chain" begin
# Note that this also checks that variables on the lhs of :=,
# such as σ2, are not included in the resulting chain
@test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])
end

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@testset "accuracy" begin
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
@test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred_vec = vec(
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
@testset "ensure that rng is respected" begin
rng = MersenneTwister(42)
predictions1 = DynamicPPL.predict(rng, m_lin_reg_test, β_chain[1:2])
predictions2 = DynamicPPL.predict(
MersenneTwister(42), m_lin_reg_test, β_chain[1:2]
)
@test all(Array(predictions1) .== Array(predictions2))
end

@testset "accuracy on vectorized model" begin
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, β_chain)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end

@testset "prediction from multiple chains" begin
# Normal linreg model
multiple_β_chain = MCMCChains.Chains(
reshape(rand(Normal(ground_truth_β, 0.002), 1000, 2), 1000, 1, 2), [:β]
)
predictions = DynamicPPL.predict(m_lin_reg_test, multiple_β_chain)
@test size(multiple_β_chain, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred = vec(
mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)
)
@test ys_pred[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end

# Vectorized linreg model
predictions_vec = DynamicPPL.predict(m_lin_reg_test_vec, multiple_β_chain)

for chain_idx in MCMCChains.chains(multiple_β_chain)
ys_pred_vec = vec(
mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)
)
@test ys_pred_vec[1] ≈ ground_truth_β * xs_test[1] atol = 0.01
@test ys_pred_vec[2] ≈ ground_truth_β * xs_test[2] atol = 0.01
end
end
end

@testset "with AbstractVector{<:AbstractVarInfo}" begin
Expand Down
Loading