Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More work on VarNameVector #637

Merged
merged 76 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
8750255
Merge branch 'master' into torfjelde/transformations
torfjelde Apr 16, 2024
ed6ee88
Merge branch 'master' into torfjelde/transformations
torfjelde Jun 18, 2024
607bdb3
Update test/model.jl
torfjelde Jun 18, 2024
55c8098
Apply suggestions from code review
torfjelde Jun 21, 2024
cc910d5
Merge remote-tracking branch 'origin/torfjelde/transformations' into …
torfjelde Jun 27, 2024
ec9f985
Merge branch 'master' into torfjelde/transformations
torfjelde Jul 14, 2024
a079606
Merge remote-tracking branch 'origin/torfjelde/transformations' into …
torfjelde Jul 19, 2024
ad959ec
Type-stability tests are now correctly using `rand_prior_true` instead
torfjelde Jul 21, 2024
9f84070
`getindex_internal` now calls `getindex` instead of `view`, as the
torfjelde Jul 21, 2024
7d39934
Removed seemingly unnecessary definition of `getindex_internal`
torfjelde Jul 21, 2024
b554504
Fixed references to `newmetadata` which has been replaced by `replace…
torfjelde Jul 28, 2024
ddb1dfe
Made implementation of `recombine` more explicit
torfjelde Jul 28, 2024
3b08f1d
Added docstrings for `untyped_varinfo` and `typed_varinfo`
torfjelde Jul 28, 2024
96ccebe
Added TODO comment about implementing `view` for `VarInfo`
torfjelde Jul 28, 2024
beaeeaa
Fixed potential infinite recursion as suggested by @mhauru
torfjelde Jul 28, 2024
ab2c98b
added docstring to `from_vec_trnasform_for_size
torfjelde Jul 28, 2024
f1f7968
Replaced references to `vectorize(dist, x)` with `tovec(x)`
torfjelde Jul 28, 2024
6e57822
Fixed docstring
torfjelde Jul 28, 2024
841215f
Update src/extract_priors.jl
torfjelde Jul 28, 2024
78b2083
Bump minor version since this is a breaking change
torfjelde Jul 28, 2024
b6ecf7b
Merge remote-tracking branch 'origin/torfjelde/transformations' into …
torfjelde Jul 28, 2024
7100ce1
Merge branch 'master' into torfjelde/transformations
torfjelde Jul 28, 2024
bab63e1
Apply suggestions from code review
sunxd3 Jul 30, 2024
6997019
Update src/varinfo.jl
sunxd3 Jul 30, 2024
9dc7f02
Apply suggestions from code review
torfjelde Jul 30, 2024
c0f9923
Apply suggestions from code review
torfjelde Jul 30, 2024
9056928
Update src/extract_priors.jl
torfjelde Aug 6, 2024
e43dd1b
Added fix for product distributions of targets with changing support …
torfjelde Aug 6, 2024
a7673fd
Addeed tests for product of distributions with dynamic support
torfjelde Aug 6, 2024
e8d4c96
Apply suggestions from code review
torfjelde Aug 6, 2024
2fe7605
Merge remote-tracking branch 'origin/torfjelde/transformations' into …
mhauru Aug 8, 2024
ca0c951
Fix typos, improve docstrings
mhauru Aug 8, 2024
60bb054
Use Accessors rather than Setfield
mhauru Aug 8, 2024
32a9ec7
Simplify group_by_symbol
mhauru Aug 8, 2024
bc41d82
Add short_varinfo_name(::VectorVarInfo)
mhauru Aug 8, 2024
6900034
Add tests for subset
mhauru Aug 8, 2024
e9be160
Export VectorVarInfo
mhauru Aug 8, 2024
2ae8516
Tighter type bound for has_varnamevector
mhauru Aug 8, 2024
524c148
Add some VectorVarName methods
mhauru Aug 8, 2024
b076aef
Add todo notes, remove dead code, fix a typo.
mhauru Aug 8, 2024
f28b430
Bug fixes and small improvements
mhauru Aug 14, 2024
5f02494
VarNameVector improvements
mhauru Aug 15, 2024
56fac99
Improve generated_quantities and its tests
mhauru Aug 19, 2024
c793ada
Improvement to VarNameVector
mhauru Aug 19, 2024
ed2d695
Fix a test to work with VectorVarName
mhauru Aug 19, 2024
01935c8
Fix generated_quantities
mhauru Aug 19, 2024
f8d0100
Fix type stability issues
mhauru Aug 21, 2024
d4ba9f5
Various VarNameVector fixes and improvements
mhauru Aug 21, 2024
fef615d
Merge remote-tracking branch 'origin/master' into mhauru/varnamevector
mhauru Aug 22, 2024
bd67b38
Bump version number
mhauru Aug 22, 2024
06d9df5
Merge remote-tracking branch 'origin/torfjelde/varnamevector' into mh…
mhauru Aug 22, 2024
9596bea
Improvements to generated_quantities
mhauru Aug 22, 2024
b8309d2
Code formatting
mhauru Aug 22, 2024
44fc385
Code style
mhauru Aug 22, 2024
ad13acf
Add fallback implementation of findinds for VarNameVector
mhauru Aug 22, 2024
d0322b7
Rename VarNameVector to VarNamedVector
mhauru Aug 22, 2024
250010d
More renaming of VNV. Remove unused VarNamedVector.metadata field.
mhauru Aug 22, 2024
02d5187
Rename FromVec to ReshapeTransform
mhauru Aug 23, 2024
94cf179
Progress towards having VarNamedVector as storage for SimpleVarInfo
mhauru Aug 28, 2024
27bac26
Fix unflatten(vnv::VarNamedVector, vals)
mhauru Aug 29, 2024
38147da
More work on SimpleVarInfo{VarNamedVector}
mhauru Aug 29, 2024
3990914
More tests for SimpleVarInfo{VarNamedVector}
mhauru Aug 29, 2024
a7a9974
More tests for SimpleVarInfo{VarNamedVector}
mhauru Aug 29, 2024
1937359
Respond to review feedback
mhauru Aug 30, 2024
d4120c3
Add float_type_with_fallback(::Type{Union{}})
mhauru Aug 30, 2024
ca04666
Move some VNV functions to the correct file
mhauru Aug 30, 2024
bb9ae76
Fix push! for VNV
mhauru Aug 30, 2024
536a476
Rename VNV.is_transformed to VNV.is_unconstrained
mhauru Aug 30, 2024
6a029bb
Improve VNV docstring
mhauru Aug 30, 2024
d8f8b17
Add VNV inner constructor checks
mhauru Aug 30, 2024
076e478
Reorganise parts of VNV code
mhauru Aug 30, 2024
f11f007
Documentation and small fixes for VNV
mhauru Sep 2, 2024
f8361f6
Rename loosen_types!! and tighten_types, add docstrings and doctests
mhauru Sep 2, 2024
004c327
Rename VarNameVector to VarNamedVector in docs
mhauru Sep 2, 2024
5291290
Documentation and small fixes to VNV
mhauru Sep 2, 2024
3d472f4
Fix subset(::VarNamedVector, args...) for unconstrained variables.
mhauru Sep 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ This does incur a runtime cost as it requires re-allocation of the `ranges` in a

!!! note

Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing he `VarName`'s transformation with a `DynamicPPL.FromVec`.
Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`.

Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNameVector` as the `metadata` field:

Expand Down
145 changes: 138 additions & 7 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end
_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
function _check_varname_indexing(c::MCMCChains.Chains)
return DynamicPPL.supports_varname_indexing(c) ||
error("Chains do not support indexing using $vn.")
error("$(typeof(c)) do not support indexing using varnmes.")
end

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
Expand Down Expand Up @@ -41,21 +41,152 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
return keys(c.info.varname_to_symbol)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)

Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.

# Examples
## General
Often you might have additional quantities computed inside the model that you want to
inspect, e.g.
```julia
@model function demo(x)
# sample and observe
θ ~ Prior()
x ~ Likelihood()
return interesting_quantity(θ, x)
end
m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
m = m_shifted - 10

for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end

return (m, )
end
demo (generic function with 1 method)

julia> model = demo(randn(10));

julia> chain = sample(model, MH(), 10);

julia> generated_quantities(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.043088571494005024,)
(-0.16489786710222099,)
(-0.16489786710222099,)
```
"""
function DynamicPPL.generated_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
chain = MCMCChains.get_sections(chain_full, :parameters)
varinfo = DynamicPPL.VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
# Update the varinfo with the current sample and make variables not present in `chain`
# to be sampled.
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
if DynamicPPL.supports_varname_indexing(chain)
varname_pairs = _varname_pairs_with_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
else
varname_pairs = _varname_pairs_without_varname_indexing(
chain, varinfo, sample_idx, chain_idx
)
end
fixed_model = DynamicPPL.fix(model, Dict(varname_pairs))
return fixed_model()
end
end

"""
_varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)

# TODO: Some of the variables can be a view into the `varinfo`, so we need to
# `deepcopy` the `varinfo` before passing it to `model`.
model(deepcopy(varinfo))
Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.

This implementation assumes `chain` can be indexed using variable names, and is the
preffered implementation.
"""
function _varname_pairs_with_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
vns = DynamicPPL.varnames(chain)
vn_parents = Iterators.map(vns) do vn
# The call nested_setindex_maybe! is used to handle cases where vn is not
# the variable name used in the model, but rather subsumed by one. Except
# for the subsumption part, this could be
# vn => getindex_varname(chain, sample_idx, vn, chain_idx)
# TODO(mhauru) This call to nested_setindex_maybe! is unintuitive.
DynamicPPL.nested_setindex_maybe!(
varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn
)
end
varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent
vn_parent => varinfo[vn_parent]
end
return varname_pairs
end

"""
Check which keys in `key_strings` are subsumed by `vn_string` and return the their values.

The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and
won't catch all cases. We should get rid of this if we can.
"""
# TODO(mhauru) See docstring above.
function _vcat_subsumed_values(vn_string, values, key_strings)
indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings)
return !isempty(indices) ? reduce(vcat, values[indices]) : nothing
end

"""
_varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)

Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values
from the chain.

This implementation does not assume that `chain` can be indexed using variable names. It is
thus not guaranteed to work in cases where the variable names have complex subsumption
patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`.
"""
function _varname_pairs_without_varname_indexing(
chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx
)
values = chain.value[sample_idx, :, chain_idx]
keys = Base.keys(chain)
keys_strings = map(string, keys)
varname_pairs = [
vn => _vcat_subsumed_values(string(vn), values, keys_strings) for
vn in Base.keys(varinfo)
]
return varname_pairs
end

end
6 changes: 3 additions & 3 deletions ext/DynamicPPLReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ else
end

function LogDensityProblemsAD.ADgradient(
ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction
) where {Tcompile}
ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction
)
return LogDensityProblemsAD.ADgradient(
Val(:ReverseDiff),
ℓ;
compile=Val(Tcompile),
compile=Val(ad.compile),
# `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0
# because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
# `zero(D)` will return 0 when D is Real.
Expand Down
5 changes: 3 additions & 2 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
TypedVarInfo,
VectorVarInfo,
Copy link
Member

@yebai yebai Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mhauru Is it possible to also introduce and test against VectorSimpleVarInfo (i.e. SimpleVarInfo with VarNameVector as storage format)? If so, can you investigate what might be needed from VarNameVector?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Working on this, and it seems doable. I got a big chunk of the test suite to pass, still need to add some more tests and see if they fail.

SimpleVarInfo,
VarNameVector,
VarNamedVector,
push!!,
empty!!,
subset,
Expand Down Expand Up @@ -176,7 +177,7 @@ include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varnamevector.jl")
include("varnamedvector.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
Expand Down
6 changes: 3 additions & 3 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

julia> # For the sake of brevity, let's just check the type.
md = values_as(vi); md.s isa DynamicPPL.Metadata
md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector}
true

julia> values_as(vi, NamedTuple)
Expand All @@ -321,7 +321,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

julia> # For the sake of brevity, let's just check the type.
values_as(vi) isa DynamicPPL.Metadata
values_as(vi) isa Union{DynamicPPL.Metadata, Vector}
true

julia> values_as(vi, NamedTuple)
Expand Down Expand Up @@ -349,7 +349,7 @@ Determine the default `eltype` of the values returned by `vi[spl]`.
This should generally not be called explicitly, as it's only used in
[`matchingvalue`](@ref) to determine the default type to use in place of
type-parameters passed to the model.

This method is considered legacy, and is likely to be deprecated in the future.
"""
function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior})
Expand Down
15 changes: 12 additions & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ function assume(
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vn, "del", true)
r = init(rng, dist, sampler)
f = to_maybe_linked_internal_transform(vi, vn, dist)
BangBang.setindex!!(vi, f(r), vn)
Expand Down Expand Up @@ -516,7 +519,10 @@ function get_and_set_val!(
if haskey(vi, vns[1])
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vns[1], "del", true)
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
Expand Down Expand Up @@ -554,7 +560,10 @@ function get_and_set_val!(
if haskey(vi, vns[1])
# Always overwrite the parameters with new ones for `SampleFromUniform`.
if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
# TODO(mhauru) Is it important to unset the flag here? The `true` allows us
# to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if
# that's okay.
unset_flag!(vi, vns[1], "del", true)
f = (vn, dist) -> init(rng, dist, spl)
r = f.(vns, dists)
for i in eachindex(vns)
Expand Down
79 changes: 4 additions & 75 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1201,74 +1201,6 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end
end

"""
generated_quantities(model::Model, chain::AbstractChains)

Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.

# Examples
## General
Often you might have additional quantities computed inside the model that you want to
inspect, e.g.
```julia
@model function demo(x)
# sample and observe
θ ~ Prior()
x ~ Likelihood()
return interesting_quantity(θ, x)
end
m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
m = m_shifted - 10

for i in eachindex(xs)
xs[i] ~ Normal(m, √s)
end

return (m, )
end
demo (generic function with 1 method)

julia> model = demo(randn(10));

julia> chain = sample(model, MH(), 10);

julia> generated_quantities(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.09270081916291417,)
(0.043088571494005024,)
(-0.16489786710222099,)
(-0.16489786710222099,)
```
"""
function generated_quantities(model::Model, chain::AbstractChains)
varinfo = VarInfo(model)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
return map(iters) do (sample_idx, chain_idx)
setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
model(varinfo)
end
end

"""
generated_quantities(model::Model, parameters::NamedTuple)
generated_quantities(model::Model, values, keys)
Expand All @@ -1295,7 +1227,7 @@ demo (generic function with 2 methods)

julia> model = demo(randn(10));

julia> parameters = (; s = 1.0, m_shifted=10);
julia> parameters = (; s = 1.0, m_shifted=10.0);

julia> generated_quantities(model, parameters)
(0.0,)
Expand All @@ -1305,13 +1237,10 @@ julia> generated_quantities(model, values(parameters), keys(parameters))
```
"""
function generated_quantities(model::Model, parameters::NamedTuple)
varinfo = VarInfo(model)
setval_and_resample!(varinfo, values(parameters), keys(parameters))
return model(varinfo)
fixed_model = fix(model, parameters)
return fixed_model()
end

function generated_quantities(model::Model, values, keys)
varinfo = VarInfo(model)
setval_and_resample!(varinfo, values, keys)
return model(varinfo)
return generated_quantities(model, NamedTuple{keys}(values))
end
2 changes: 1 addition & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ function set_values!!(
flattened_param_vals = varinfo[spl]
length(flattened_param_vals) == length(initial_params) || throw(
DimensionMismatch(
"Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(theta)))",
"Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))",
),
)

Expand Down
Loading