diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f7d43470e..4015ab331 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -4,9 +4,11 @@ on: push: branches: - master + - backport-* pull_request: branches: - master + - backport-* merge_group: types: [checks_requested] diff --git a/Project.toml b/Project.toml index eab8c362c..5dde3a427 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.30.1" +version = "0.30.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -46,7 +46,7 @@ AbstractMCMC = "5" AbstractPPL = "0.8.4, 0.9" Accessors = "0.1" BangBang = "0.4.1" -Bijectors = "0.13.18" +Bijectors = "0.13.18, 0.14" ChainRulesCore = "1" Compat = "4" ConstructionBase = "1.5.4" @@ -65,7 +65,7 @@ Requires = "1" ReverseDiff = "1" Test = "1.6" ZygoteRules = "0.2" -julia = "~1.6.6, 1.7.3" +julia = "1.10" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index eb9745552..f950de6f1 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -108,86 +108,14 @@ function DynamicPPL.returned_quantities( varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - if DynamicPPL.supports_varname_indexing(chain) - varname_pairs = _varname_pairs_with_varname_indexing( - chain, varinfo, sample_idx, chain_idx - ) - else - varname_pairs = _varname_pairs_without_varname_indexing( - chain, varinfo, sample_idx, chain_idx - ) - end - fixed_model = DynamicPPL.fix(model, Dict(varname_pairs)) - return fixed_model() + # TODO: Use `fix` once we've addressed https://github.com/TuringLang/DynamicPPL.jl/issues/702. + # Update the varinfo with the current sample and make variables not present in `chain` + # to be sampled. + DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) + # NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to + # `deepcopy` the `varinfo` before passing it to the `model`. + model(deepcopy(varinfo)) end end -""" - _varname_pairs_with_varname_indexing( - chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx - ) - -Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values -from the chain. - -This implementation assumes `chain` can be indexed using variable names, and is the -preffered implementation. -""" -function _varname_pairs_with_varname_indexing( - chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx -) - vns = DynamicPPL.varnames(chain) - vn_parents = Iterators.map(vns) do vn - # The call nested_setindex_maybe! is used to handle cases where vn is not - # the variable name used in the model, but rather subsumed by one. Except - # for the subsumption part, this could be - # vn => getindex_varname(chain, sample_idx, vn, chain_idx) - # TODO(mhauru) This call to nested_setindex_maybe! is unintuitive. - DynamicPPL.nested_setindex_maybe!( - varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn - ) - end - varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent - vn_parent => varinfo[vn_parent] - end - return varname_pairs -end - -""" -Check which keys in `key_strings` are subsumed by `vn_string` and return the their values. - -The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and -won't catch all cases. We should get rid of this if we can. -""" -# TODO(mhauru) See docstring above. -function _vcat_subsumed_values(vn_string, values, key_strings) - indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings) - return !isempty(indices) ? reduce(vcat, values[indices]) : nothing -end - -""" - _varname_pairs_without_varname_indexing( - chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx - ) - -Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values -from the chain. - -This implementation does not assume that `chain` can be indexed using variable names. It is -thus not guaranteed to work in cases where the variable names have complex subsumption -patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`. -""" -function _varname_pairs_without_varname_indexing( - chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx -) - values = chain.value[sample_idx, :, chain_idx] - keys = Base.keys(chain) - keys_strings = map(string, keys) - varname_pairs = [ - vn => _vcat_subsumed_values(string(vn), values, keys_strings) for - vn in Base.keys(varinfo) - ] - return varname_pairs -end - end diff --git a/src/utils.jl b/src/utils.jl index bd5d365fc..5fedd3039 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -286,8 +286,15 @@ function (f::ReshapeTransform)(x) if size(x) != f.input_size throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))")) end - # The call to `tovec` is only needed in case `x` is a scalar. - return reshape(tovec(x), f.output_size) + if f.output_size == () + # Specially handle the case where x is a singleton array, see + # https://github.com/JuliaDiff/ReverseDiff.jl/issues/265 and + # https://github.com/TuringLang/DynamicPPL.jl/issues/698 + return fill(x[], ()) + else + # The call to `tovec` is only needed in case `x` is a scalar. + return reshape(tovec(x), f.output_size) + end end function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x) @@ -934,10 +941,10 @@ end """ float_type_with_fallback(x) -Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`. +Return type corresponding to `float(typeof(x))` if possible; otherwise return `float(Real)`. """ -float_type_with_fallback(::Type) = Real -float_type_with_fallback(::Type{Union{}}) = Real +float_type_with_fallback(::Type) = float(Real) +float_type_with_fallback(::Type{Union{}}) = float(Real) float_type_with_fallback(::Type{T}) where {T<:Real} = float(T) """ diff --git a/src/varinfo.jl b/src/varinfo.jl index 11fe08d0f..4cf1f1b02 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1694,6 +1694,8 @@ Return the current value(s) of the random variables sampled by `spl` in `vi`. The value(s) may or may not be transformed to Euclidean space. """ +getindex(vi::UntypedVarInfo, spl::Sampler) = + copy(getindex(vi.metadata.vals, _getranges(vi, spl))) getindex(vi::VarInfo, spl::Sampler) = copy(getindex_internal(vi, _getranges(vi, spl))) function getindex(vi::TypedVarInfo, spl::Sampler) # Gets the ranges as a NamedTuple diff --git a/test/Project.toml b/test/Project.toml index 36ee4baa8..36fcd1b69 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -31,7 +31,7 @@ ADTypes = "1" AbstractMCMC = "5" AbstractPPL = "0.8.4, 0.9" Accessors = "0.1" -Bijectors = "0.13.9" +Bijectors = "0.13.9, 0.14" Combinatorics = "1" Compat = "4.3.0" Distributions = "0.25" diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 0dd3f20e0..28341c20b 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -15,5 +15,5 @@ DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30" HypothesisTests = "0.11" MCMCChains = "6" ReverseDiff = "1.15" -Turing = "0.33, 0.34" +Turing = "0.33, 0.34, 0.35" julia = "1.7" diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index f1d805505..e5b8eb79f 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -342,4 +342,19 @@ model = state_space(y, length(t)) @test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n end + + if Threads.nthreads() > 1 + @testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin + @model function f(x) + ns ~ filldist(Normal(0, 2.0), 3) + m ~ Uniform(0, 1) + return x ~ Normal(m, 1) + end + model = f(1) + chain = sample(model, NUTS(), MCMCThreads(), 10, 2) + loglikelihood(model, chain) + logprior(model, chain) + logjoint(model, chain) + end + end end