diff --git a/Project.toml b/Project.toml index 92fb67ddd..32951e5cb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.30.2" +version = "0.30.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index c91fb1fe0..8a2679d09 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -108,86 +108,14 @@ function DynamicPPL.generated_quantities( varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - if DynamicPPL.supports_varname_indexing(chain) - varname_pairs = _varname_pairs_with_varname_indexing( - chain, varinfo, sample_idx, chain_idx - ) - else - varname_pairs = _varname_pairs_without_varname_indexing( - chain, varinfo, sample_idx, chain_idx - ) - end - fixed_model = DynamicPPL.fix(model, Dict(varname_pairs)) - return fixed_model() + # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. + # Update the varinfo with the current sample and make variables not present in `chain` + # to be sampled. + DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) + # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to + # `deepcopy` the `varinfo` before passing it to the `model`. + model(deepcopy(varinfo)) end end -""" - _varname_pairs_with_varname_indexing( - chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx - ) - -Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values -from the chain. - -This implementation assumes `chain` can be indexed using variable names, and is the -preffered implementation. -""" -function _varname_pairs_with_varname_indexing( - chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx -) - vns = DynamicPPL.varnames(chain) - vn_parents = Iterators.map(vns) do vn - # The call nested_setindex_maybe! is used to handle cases where vn is not - # the variable name used in the model, but rather subsumed by one. Except - # for the subsumption part, this could be - # vn => getindex_varname(chain, sample_idx, vn, chain_idx) - # TODO(mhauru) This call to nested_setindex_maybe! is unintuitive. - DynamicPPL.nested_setindex_maybe!( - varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn - ) - end - varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent - vn_parent => varinfo[vn_parent] - end - return varname_pairs -end - -""" -Check which keys in `key_strings` are subsumed by `vn_string` and return the their values. - -The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and -won't catch all cases. We should get rid of this if we can. -""" -# TODO(mhauru) See docstring above. -function _vcat_subsumed_values(vn_string, values, key_strings) - indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings) - return !isempty(indices) ? reduce(vcat, values[indices]) : nothing -end - -""" - _varname_pairs_without_varname_indexing( - chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx - ) - -Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values -from the chain. - -This implementation does not assume that `chain` can be indexed using variable names. It is -thus not guaranteed to work in cases where the variable names have complex subsumption -patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`. -""" -function _varname_pairs_without_varname_indexing( - chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx -) - values = chain.value[sample_idx, :, chain_idx] - keys = Base.keys(chain) - keys_strings = map(string, keys) - varname_pairs = [ - vn => _vcat_subsumed_values(string(vn), values, keys_strings) for - vn in Base.keys(varinfo) - ] - return varname_pairs -end - end