Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict values_as_in_model API #778

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.33.1"
version = "0.34.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
41 changes: 15 additions & 26 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ wants to extract the realization of a model in a constrained space.
# Fields
$(TYPEDFIELDS)
"""
struct ValuesAsInModelContext{T,C<:AbstractContext} <: AbstractContext
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
"values that are extracted from the model"
values::T
values::OrderedDict
"whether to extract variables on the LHS of :="
include_colon_eq::Bool
"child context"
Expand Down Expand Up @@ -114,34 +114,32 @@ function dot_tilde_assume(
end

"""
values_as_in_model(model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(rng::Random.AbstractRNG, model::Model, include_colon_eq::Bool[, varinfo::AbstractVarInfo, context::AbstractContext])
values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext])
Get the values of `varinfo` as they would be seen in the model.
If no `varinfo` is provided, then this is effectively the same as
[`Base.rand(rng::Random.AbstractRNG, model::Model)`](@ref).
More specifically, this method attempts to extract the realization _as seen in
the model_. For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a
realization that is compatible with `truncated(Normal(); lower=0)` -- i.e. one
where the value of `x[1]` is positive -- regardless of whether `varinfo` is
working in unconstrained space.
More specifically, this method attempts to extract the realization _as seen in the model_.
For example, `x[1] ~ truncated(Normal(); lower=0)` will result in a realization compatible
with `truncated(Normal(); lower=0)` regardless of whether `varinfo` is working in unconstrained
space.
Hence this method is a "safe" way of obtaining realizations in constrained space at the cost
of additional model evaluations.
Hence this method is a "safe" way of obtaining realizations in constrained
space at the cost of additional model evaluations.
# Arguments
- `model::Model`: model to extract realizations from.
- `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`.
- `varinfo::AbstractVarInfo`: variable information to use for the extraction.
- `context::AbstractContext`: context to use for the extraction. If `rng` is specified, then `context`
will be wrapped in a [`SamplingContext`](@ref) with the provided `rng`.
- `context::AbstractContext`: base context to use for the extraction. Defaults
to `DynamicPPL.DefaultContext()`.
# Examples
## When `VarInfo` fails
The following demonstrates a common pitfall when working with [`VarInfo`](@ref) and constrained variables.
The following demonstrates a common pitfall when working with [`VarInfo`](@ref)
and constrained variables.
```jldoctest
julia> using Distributions, StableRNGs
Expand Down Expand Up @@ -191,19 +189,10 @@ true
function values_as_in_model(
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
varinfo::AbstractVarInfo,
context::AbstractContext=DefaultContext(),
)
context = ValuesAsInModelContext(include_colon_eq, context)
evaluate!!(model, varinfo, context)
return context.values
end
function values_as_in_model(
rng::Random.AbstractRNG,
model::Model,
include_colon_eq::Bool,
varinfo::AbstractVarInfo=VarInfo(),
context::AbstractContext=DefaultContext(),
)
return values_as_in_model(model, true, varinfo, SamplingContext(rng, context))
end
16 changes: 0 additions & 16 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,22 +429,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end
end

@testset "check that sampling obeys rng if passed" begin
@model function f()
x ~ Normal(0)
return y ~ Normal(x)
end
model = f()
# Call values_as_in_model with the rng
values = values_as_in_model(Random.Xoshiro(43), model, false)
# Check that they match the values that would be used if vi was seeded
# with that seed instead
expected_vi = VarInfo(Random.Xoshiro(43), model)
for vn in keys(values)
@test values[vn] == expected_vi[vn]
end
end
end

@testset "Erroneous model call" begin
Expand Down
Loading