diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index 8d74fd2fa..d1d2e7116 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -132,13 +132,13 @@ Hence we obtain a "type-stable when possible"-representation by wrapping it in a ## Efficient storage and iteration -Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`VarNameVector`](@ref): +Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`VarNamedVector`](@ref): ```@docs -DynamicPPL.VarNameVector +DynamicPPL.VarNamedVector ``` -In a [`VarNameVector{<:VarName,Vector{T}}`](@ref), we achieve the desirata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s. +In a [`VarNamedVector{<:VarName,Vector{T}}`](@ref), we achieve the desirata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s. This does require a bit of book-keeping, in particular when it comes to insertions and deletions. Internally, this is handled by assigning each `VarName` a unique `Int` index in the `varname_to_index` field, which is then used to index into the following fields: @@ -146,7 +146,7 @@ This does require a bit of book-keeping, in particular when it comes to insertio - `ranges::Vector{UnitRange{Int}}`: the ranges of indices in the `Vector{T}` that correspond to each `VarName`. - `transforms::Vector`: the transforms associated with each `VarName`. -Mutating functions, e.g. `setindex!(vnv::VarNameVector, val, vn::VarName)`, are then treated according to the following rules: +Mutating functions, e.g. `setindex!(vnv::VarNamedVector, val, vn::VarName)`, are then treated according to the following rules: 1. If `vn` is not already present: add it to the end of `vnv.varnames`, add the `val` to the underlying `vnv.vals`, etc. @@ -156,7 +156,7 @@ Mutating functions, e.g. `setindex!(vnv::VarNameVector, val, vn::VarName)`, are 2. If `val` has a *smaller length* than the existing value for `vn`: replace existing value and mark the remaining indices as "inactive" by increasing the entry in `vnv.num_inactive` field. 3. If `val` has a *larger length* than the existing value for `vn`: expand the underlying `vnv.vals` to accommodate the new value, update all `VarName`s occuring after `vn`, and update the `vnv.ranges` to point to the new range for `vn`. -This means that `VarNameVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in. +This means that `VarNamedVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in. For example, we want to optimize code-paths which effectively boil down to inner-loop in the following example: @@ -195,7 +195,7 @@ DynamicPPL.contiguify! For example, one might encounter the following scenario: ```@example varinfo-design -vnv = DynamicPPL.VarNameVector(@varname(x) => [true]) +vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") for i in 1:5 @@ -210,7 +210,7 @@ end We can then insert a call to [`DynamicPPL.contiguify!`](@ref) after every insertion whenever the allocation grows too large to reduce overall memory usage: ```@example varinfo-design -vnv = DynamicPPL.VarNameVector(@varname(x) => [true]) +vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") for i in 1:5 @@ -225,13 +225,13 @@ for i in 1:5 end ``` -This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNameVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous. +This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNamedVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous. !!! note - Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing he `VarName`'s transformation with a `DynamicPPL.FromVec`. + Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`. -Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNameVector` as the `metadata` field: +Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNamedVector` as the `metadata` field: ```@example varinfo-design # Type-unstable @@ -287,23 +287,23 @@ DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) ### Performance summary -In the end, we have the following "rough" performance characteristics for `VarNameVector`: +In the end, we have the following "rough" performance characteristics for `VarNamedVector`: -| Method | Is blazingly fast? | -|:---------------------------------------:|:--------------------------------------------------------------------------------------------:| -| `getindex` | ${\color{green} \checkmark}$ | -| `setindex!` | ${\color{green} \checkmark}$ | -| `push!` | ${\color{green} \checkmark}$ | -| `delete!` | ${\color{red} \times}$ | -| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size | -| `values_as(::VarNameVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise | +| Method | Is blazingly fast? | +|:----------------------------------------:|:--------------------------------------------------------------------------------------------:| +| `getindex` | ${\color{green} \checkmark}$ | +| `setindex!` | ${\color{green} \checkmark}$ | +| `push!` | ${\color{green} \checkmark}$ | +| `delete!` | ${\color{red} \times}$ | +| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size | +| `values_as(::VarNamedVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise | ## Other methods ```@docs -DynamicPPL.replace_values(::VarNameVector, vals::AbstractVector) +DynamicPPL.replace_values(::VarNamedVector, vals::AbstractVector) ``` ```@docs; canonical=false -DynamicPPL.values_as(::VarNameVector) +DynamicPPL.values_as(::VarNamedVector) ``` diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 8c598a6a8..f362b02cc 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -11,7 +11,7 @@ end _has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names function _check_varname_indexing(c::MCMCChains.Chains) return DynamicPPL.supports_varname_indexing(c) || - error("Chains do not support indexing using $vn.") + error("$(typeof(c)) do not support indexing using varnmes.") end # Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata @@ -41,6 +41,65 @@ function DynamicPPL.varnames(c::MCMCChains.Chains) return keys(c.info.varname_to_symbol) end +""" + generated_quantities(model::Model, chain::MCMCChains.Chains) + +Execute `model` for each of the samples in `chain` and return an array of the values +returned by the `model` for each sample. + +# Examples +## General +Often you might have additional quantities computed inside the model that you want to +inspect, e.g. +```julia +@model function demo(x) + # sample and observe + θ ~ Prior() + x ~ Likelihood() + return interesting_quantity(θ, x) +end +m = demo(data) +chain = sample(m, alg, n) +# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples +# from the posterior/`chain`: +generated_quantities(m, chain) # <= results in a `Vector` of returned values + # from `interesting_quantity(θ, x)` +``` +## Concrete (and simple) +```julia +julia> using DynamicPPL, Turing + +julia> @model function demo(xs) + s ~ InverseGamma(2, 3) + m_shifted ~ Normal(10, √s) + m = m_shifted - 10 + + for i in eachindex(xs) + xs[i] ~ Normal(m, √s) + end + + return (m, ) + end +demo (generic function with 1 method) + +julia> model = demo(randn(10)); + +julia> chain = sample(model, MH(), 10); + +julia> generated_quantities(model, chain) +10×1 Array{Tuple{Float64},2}: + (2.1964758025119338,) + (2.1964758025119338,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.09270081916291417,) + (0.043088571494005024,) + (-0.16489786710222099,) + (-0.16489786710222099,) +``` +""" function DynamicPPL.generated_quantities( model::DynamicPPL.Model, chain_full::MCMCChains.Chains ) @@ -48,14 +107,86 @@ function DynamicPPL.generated_quantities( varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - # Update the varinfo with the current sample and make variables not present in `chain` - # to be sampled. - DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) + if DynamicPPL.supports_varname_indexing(chain) + varname_pairs = _varname_pairs_with_varname_indexing( + chain, varinfo, sample_idx, chain_idx + ) + else + varname_pairs = _varname_pairs_without_varname_indexing( + chain, varinfo, sample_idx, chain_idx + ) + end + fixed_model = DynamicPPL.fix(model, Dict(varname_pairs)) + return fixed_model() + end +end + +""" + _varname_pairs_with_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx + ) - # TODO: Some of the variables can be a view into the `varinfo`, so we need to - # `deepcopy` the `varinfo` before passing it to `model`. - model(deepcopy(varinfo)) +Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values +from the chain. + +This implementation assumes `chain` can be indexed using variable names, and is the +preffered implementation. +""" +function _varname_pairs_with_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx +) + vns = DynamicPPL.varnames(chain) + vn_parents = Iterators.map(vns) do vn + # The call nested_setindex_maybe! is used to handle cases where vn is not + # the variable name used in the model, but rather subsumed by one. Except + # for the subsumption part, this could be + # vn => getindex_varname(chain, sample_idx, vn, chain_idx) + # TODO(mhauru) This call to nested_setindex_maybe! is unintuitive. + DynamicPPL.nested_setindex_maybe!( + varinfo, DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx), vn + ) end + varname_pairs = Iterators.map(Iterators.filter(!isnothing, vn_parents)) do vn_parent + vn_parent => varinfo[vn_parent] + end + return varname_pairs +end + +""" +Check which keys in `key_strings` are subsumed by `vn_string` and return the their values. + +The subsumption check is done with `DynamicPPL.subsumes_string`, which is quite weak, and +won't catch all cases. We should get rid of this if we can. +""" +# TODO(mhauru) See docstring above. +function _vcat_subsumed_values(vn_string, values, key_strings) + indices = findall(Base.Fix1(DynamicPPL.subsumes_string, vn_string), key_strings) + return !isempty(indices) ? reduce(vcat, values[indices]) : nothing +end + +""" + _varname_pairs_without_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx + ) + +Get pairs of `VarName => value` for all the variables in the `varinfo`, picking the values +from the chain. + +This implementation does not assume that `chain` can be indexed using variable names. It is +thus not guaranteed to work in cases where the variable names have complex subsumption +patterns, such as if the model has a variable `x` but the chain stores `x.a[1]`. +""" +function _varname_pairs_without_varname_indexing( + chain::MCMCChains.Chains, varinfo, sample_idx, chain_idx +) + values = chain.value[sample_idx, :, chain_idx] + keys = Base.keys(chain) + keys_strings = map(string, keys) + varname_pairs = [ + vn => _vcat_subsumed_values(string(vn), values, keys_strings) for + vn in Base.keys(varinfo) + ] + return varname_pairs end end diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl index 3fd174ed1..b2b378d45 100644 --- a/ext/DynamicPPLReverseDiffExt.jl +++ b/ext/DynamicPPLReverseDiffExt.jl @@ -9,12 +9,12 @@ else end function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction -) where {Tcompile} + ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction +) return LogDensityProblemsAD.ADgradient( Val(:ReverseDiff), ℓ; - compile=Val(Tcompile), + compile=Val(ad.compile), # `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0 # because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473 # `zero(D)` will return 0 when D is Real. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 255c4b51b..969d69936 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -45,8 +45,9 @@ export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, + VectorVarInfo, SimpleVarInfo, - VarNameVector, + VarNamedVector, push!!, empty!!, subset, @@ -176,7 +177,7 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") -include("varnamevector.jl") +include("varnamedvector.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 7ddd09b2e..551bf87d3 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -295,7 +295,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; julia> # For the sake of brevity, let's just check the type. - md = values_as(vi); md.s isa DynamicPPL.Metadata + md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector} true julia> values_as(vi, NamedTuple) @@ -321,7 +321,7 @@ julia> # Just use an example model to construct the `VarInfo` because we're lazy julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; julia> # For the sake of brevity, let's just check the type. - values_as(vi) isa DynamicPPL.Metadata + values_as(vi) isa Union{DynamicPPL.Metadata, Vector} true julia> values_as(vi, NamedTuple) @@ -349,7 +349,7 @@ Determine the default `eltype` of the values returned by `vi[spl]`. This should generally not be called explicitly, as it's only used in [`matchingvalue`](@ref) to determine the default type to use in place of type-parameters passed to the model. - + This method is considered legacy, and is likely to be deprecated in the future. """ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromPrior}) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 13231837f..1961965ca 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -240,7 +240,10 @@ function assume( if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if + # that's okay. + unset_flag!(vi, vn, "del", true) r = init(rng, dist, sampler) f = to_maybe_linked_internal_transform(vi, vn, dist) BangBang.setindex!!(vi, f(r), vn) @@ -516,7 +519,10 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if + # that's okay. + unset_flag!(vi, vns[1], "del", true) r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] @@ -554,7 +560,10 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + # TODO(mhauru) Is it important to unset the flag here? The `true` allows us + # to ignore the fact that for VarNamedVector this does nothing, but I'm unsure if + # that's okay. + unset_flag!(vi, vns[1], "del", true) f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) for i in eachindex(vns) diff --git a/src/model.jl b/src/model.jl index 09c0c1be1..1003efaf6 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1201,74 +1201,6 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end end -""" - generated_quantities(model::Model, chain::AbstractChains) - -Execute `model` for each of the samples in `chain` and return an array of the values -returned by the `model` for each sample. - -# Examples -## General -Often you might have additional quantities computed inside the model that you want to -inspect, e.g. -```julia -@model function demo(x) - # sample and observe - θ ~ Prior() - x ~ Likelihood() - return interesting_quantity(θ, x) -end -m = demo(data) -chain = sample(m, alg, n) -# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples -# from the posterior/`chain`: -generated_quantities(m, chain) # <= results in a `Vector` of returned values - # from `interesting_quantity(θ, x)` -``` -## Concrete (and simple) -```julia -julia> using DynamicPPL, Turing - -julia> @model function demo(xs) - s ~ InverseGamma(2, 3) - m_shifted ~ Normal(10, √s) - m = m_shifted - 10 - - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - - return (m, ) - end -demo (generic function with 1 method) - -julia> model = demo(randn(10)); - -julia> chain = sample(model, MH(), 10); - -julia> generated_quantities(model, chain) -10×1 Array{Tuple{Float64},2}: - (2.1964758025119338,) - (2.1964758025119338,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.09270081916291417,) - (0.043088571494005024,) - (-0.16489786710222099,) - (-0.16489786710222099,) -``` -""" -function generated_quantities(model::Model, chain::AbstractChains) - varinfo = VarInfo(model) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - return map(iters) do (sample_idx, chain_idx) - setval_and_resample!(varinfo, chain, sample_idx, chain_idx) - model(varinfo) - end -end - """ generated_quantities(model::Model, parameters::NamedTuple) generated_quantities(model::Model, values, keys) @@ -1295,7 +1227,7 @@ demo (generic function with 2 methods) julia> model = demo(randn(10)); -julia> parameters = (; s = 1.0, m_shifted=10); +julia> parameters = (; s = 1.0, m_shifted=10.0); julia> generated_quantities(model, parameters) (0.0,) @@ -1305,13 +1237,10 @@ julia> generated_quantities(model, values(parameters), keys(parameters)) ``` """ function generated_quantities(model::Model, parameters::NamedTuple) - varinfo = VarInfo(model) - setval_and_resample!(varinfo, values(parameters), keys(parameters)) - return model(varinfo) + fixed_model = fix(model, parameters) + return fixed_model() end function generated_quantities(model::Model, values, keys) - varinfo = VarInfo(model) - setval_and_resample!(varinfo, values, keys) - return model(varinfo) + return generated_quantities(model, NamedTuple{keys}(values)) end diff --git a/src/sampler.jl b/src/sampler.jl index cfc58942e..833aaf7e2 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -150,7 +150,7 @@ function set_values!!( flattened_param_vals = varinfo[spl] length(flattened_param_vals) == length(initial_params) || throw( DimensionMismatch( - "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(theta)))", + "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))", ), ) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d8afb9cec..06a151f82 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -322,15 +322,17 @@ Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn) function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) return map(Base.Fix1(getindex, vi), vns) end -# HACK: Needed to disambiguiate. +# HACK: Needed to disambiguate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) # `AbstractDict` -function getindex_internal(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) - return nested_getindex(vi.values, vn) +function getindex_internal( + vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName +) + return getvalue(vi.values, vn) end Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) @@ -399,14 +401,28 @@ end function BangBang.push!!( vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, - r, + value, dist::Distribution, gidset::Set{Selector}, ) - vi.values[vn] = r + vi.values[vn] = value return vi end +function BangBang.push!!( + vi::SimpleVarInfo{<:VarNamedVector}, + vn::VarName, + value, + dist::Distribution, + gidset::Set{Selector}, +) + # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For + # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not. + # Hence we need to call update!! here, which has the same semantics as push!! does for + # SimpleVarInfo. + return Accessors.@set vi.values = update!!(vi.values, vn, value) +end + const SimpleOrThreadSafeSimple{T,V,C} = Union{ SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} } @@ -456,6 +472,8 @@ function _subset(x::NamedTuple, vns) return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms))) end +_subset(x::VarNamedVector, vns) = subset(x, vns) + # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) @@ -563,6 +581,9 @@ end function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple}) return NamedTuple((Symbol(k), v) for (k, v) in vi.values) end +function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} + return values_as(vi.values, T) +end """ logjoint(model::Model, θ) @@ -708,3 +729,5 @@ end function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()]) end + +has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/test_utils.jl b/src/test_utils.jl index 6f7481c40..9a606b4ef 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -37,20 +37,35 @@ function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) # VarInfo - vi_untyped = VarInfo() - model(vi_untyped) - vi_typed = DynamicPPL.TypedVarInfo(vi_untyped) + vi_untyped_metadata = VarInfo(DynamicPPL.Metadata()) + vi_untyped_vnv = VarInfo(DynamicPPL.VarNamedVector()) + model(vi_untyped_metadata) + model(vi_untyped_vnv) + vi_typed_metadata = DynamicPPL.TypedVarInfo(vi_untyped_metadata) + vi_typed_vnv = DynamicPPL.TypedVarInfo(vi_untyped_vnv) + # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) svi_untyped = SimpleVarInfo(OrderedDict()) + svi_vnv = SimpleVarInfo(VarNamedVector()) # SimpleVarInfo{<:Any,<:Ref} svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed))) svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) + svi_vnv_ref = SimpleVarInfo(VarNamedVector(), Ref(getlogp(svi_vnv))) - lp = getlogp(vi_typed) + lp = getlogp(vi_typed_metadata) varinfos = map(( - vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref + vi_untyped_metadata, + vi_untyped_vnv, + vi_typed_metadata, + vi_typed_vnv, + svi_typed, + svi_untyped, + svi_vnv, + svi_typed_ref, + svi_untyped_ref, + svi_vnv_ref, )) do vi # Set them all to the same values. DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f3bc84935..196f243bf 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -55,7 +55,7 @@ function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) end -has_varnamevector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamevector(vi.varinfo) +has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) function BangBang.push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} diff --git a/src/utils.jl b/src/utils.jl index 9ddeb6247..75bcef327 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -48,7 +48,7 @@ true i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density. If you would like to avoid this behaviour you should check the evaluation context. It can be accessed with the internal variable `__context__`. - For instance, in the following example the log density is not accumulated when only the log prior is computed: + For instance, in the following example the log density is not accumulated when only the log prior is computed: ```jldoctest; setup = :(using Distributions) julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); @@ -225,21 +225,30 @@ invlink_transform(dist) = inverse(link_transform(dist)) # Helper functions for vectorize/reconstruct values # ##################################################### -# Useful transformation going from the flattened representation. -struct FromVec{Size} <: Bijectors.Bijector +""" + ReshapeTransform(size::Size) + +A `Bijector` that transforms an `AbstractVector` to a realization of size `size`. As a +special case, if `size` is an empty tuple it transforms a singleton vector into a scalar. + +This transformation can be inverted by calling `tovec`. +""" +struct ReshapeTransform{Size} <: Bijectors.Bijector size::Size end -FromVec(x::Union{Real,AbstractArray}) = FromVec(size(x)) +ReshapeTransform(x::Union{Real,AbstractArray}) = ReshapeTransform(size(x)) # TODO: Should we materialize the `reshape`? -(f::FromVec)(x) = reshape(x, f.size) -(f::FromVec{Tuple{}})(x) = only(x) +(f::ReshapeTransform)(x::AbstractVector) = reshape(x, f.size) +(f::ReshapeTransform{Tuple{}})(x::AbstractVector) = only(x) # TODO: Specialize for `Tuple{<:Any}` since this correspond to a `Vector`. -Bijectors.with_logabsdet_jacobian(f::FromVec, x) = (f(x), 0) -# We want to use the inverse of `FromVec` so it preserves the size information. -Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:FromVec}, x) = (tovec(x), 0) +Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), 0) +# We want to use the inverse of `ReshapeTransform` so it preserves the size information. +function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ReshapeTransform}, x) + return (tovec(x), 0) +end struct ToChol <: Bijectors.Bijector uplo::Char @@ -254,15 +263,16 @@ Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) = Return the transformation from the vector representation of `x` to original representation. """ from_vec_transform(x::Union{Real,AbstractArray}) = from_vec_transform_for_size(size(x)) -from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ FromVec(size(C.UL)) +from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ ReshapeTransform(size(C.UL)) """ from_vec_transform_for_size(sz::Tuple) -Return the transformation from the vector representation of a realization of size `sz` to original representation. +Return the transformation from the vector representation of a realization of size `sz` to +original representation. """ -from_vec_transform_for_size(sz::Tuple) = FromVec(sz) -from_vec_transform_for_size(::Tuple{()}) = FromVec(()) +from_vec_transform_for_size(sz::Tuple) = ReshapeTransform(sz) +from_vec_transform_for_size(::Tuple{()}) = ReshapeTransform(()) from_vec_transform_for_size(::Tuple{<:Any}) = identity """ @@ -272,7 +282,7 @@ Return the transformation from the vector representation of a realization from distribution `dist` to the original representation compatible with `dist`. """ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) -from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ FromVec(size(dist)) +from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist)) """ from_vec_transform(f, size::Tuple) @@ -854,6 +864,7 @@ end Return type corresponding to `float(typeof(x))` if possible; otherwise return `Real`. """ float_type_with_fallback(::Type) = Real +float_type_with_fallback(::Type{Union{}}) = Real float_type_with_fallback(::Type{T}) where {T<:Real} = float(T) """ diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 52ba6eb61..c5003d53a 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -177,7 +177,7 @@ julia> # Approach 1: Convert back to constrained space using `invlink` and extra julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions # used in the very first model evaluation, hence the support of `y` # is not updated even though `x` has changed. - lb ≤ varinfo_invlinked[@varname(y)] ≤ ub + lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub false julia> # Approach 2: Extract realizations using `values_as_in_model`. diff --git a/src/varinfo.jl b/src/varinfo.jl index 19c178a03..a6a5c0400 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -101,7 +101,7 @@ struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo logp::Base.RefValue{Tlogp} num_produce::Base.RefValue{Int} end -const VectorVarInfo = VarInfo{<:VarNameVector} +const VectorVarInfo = VarInfo{<:VarNamedVector} const UntypedVarInfo = VarInfo{<:Metadata} const TypedVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ @@ -120,9 +120,9 @@ function VarInfo(old_vi::VarInfo, spl, x::AbstractVector) ) end -# No-op if we're already working with a `VarNameVector`. -metadata_to_varnamevector(vnv::VarNameVector) = vnv -function metadata_to_varnamevector(md::Metadata) +# No-op if we're already working with a `VarNamedVector`. +metadata_to_varnamedvector(vnv::VarNamedVector) = vnv +function metadata_to_varnamedvector(md::Metadata) idcs = copy(md.idcs) vns = copy(md.vns) ranges = copy(md.ranges) @@ -132,32 +132,32 @@ function metadata_to_varnamevector(md::Metadata) from_vec_transform(dist) end - return VarNameVector( + return VarNamedVector( OrderedDict{eltype(keys(idcs)),Int}(idcs), vns, ranges, vals, transforms ) end function VectorVarInfo(vi::UntypedVarInfo) - md = metadata_to_varnamevector(vi.metadata) + md = metadata_to_varnamedvector(vi.metadata) lp = getlogp(vi) return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) end function VectorVarInfo(vi::TypedVarInfo) - md = map(metadata_to_varnamevector, vi.metadata) + md = map(metadata_to_varnamedvector, vi.metadata) lp = getlogp(vi) return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) end """ - has_varnamevector(varinfo::VarInfo) + has_varnamedvector(varinfo::VarInfo) -Returns `true` if `varinfo` uses `VarNameVector` as metadata. +Returns `true` if `varinfo` uses `VarNamedVector` as metadata. """ -has_varnamevector(vi) = false -function has_varnamevector(vi::VarInfo) - return vi.metadata isa VarNameVector || - (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNameVector), values(vi.metadata))) +has_varnamedvector(vi::AbstractVarInfo) = false +function has_varnamedvector(vi::VarInfo) + return vi.metadata isa VarNamedVector || + (vi isa TypedVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata))) end """ @@ -170,8 +170,9 @@ function untyped_varinfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), + metadata_type::Type=VarNamedVector, ) - varinfo = VarInfo() + varinfo = VarInfo(metadata_type()) return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))) end function untyped_varinfo(model::Model, args::Union{AbstractSampler,AbstractContext}...) @@ -190,8 +191,9 @@ function VarInfo( model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), + metadata_type::Type=VarNamedVector, ) - return typed_varinfo(rng, model, sampler, context) + return typed_varinfo(rng, model, sampler, context, metadata_type) end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) @@ -300,6 +302,11 @@ function subset(varinfo::UntypedVarInfo, vns::AbstractVector{<:VarName}) return VarInfo(metadata, varinfo.logp, varinfo.num_produce) end +function subset(varinfo::VectorVarInfo, vns::AbstractVector{<:VarName}) + metadata = subset(varinfo.metadata, vns) + return VarInfo(metadata, varinfo.logp, varinfo.num_produce) +end + function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName{sym}}) where {sym} # If all the variables are using the same symbol, then we can just extract that field from the metadata. metadata = subset(getfield(varinfo.metadata, sym), vns) @@ -379,7 +386,7 @@ function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) ) end -function merge_metadata(vnv_left::VarNameVector, vnv_right::VarNameVector) +function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) return merge(vnv_left, vnv_right) end @@ -518,8 +525,6 @@ end const VarView = Union{Int,UnitRange,Vector{Int}} -getindex_internal(vi::UntypedVarInfo, vview::VarView) = view(vi.metadata.vals, vview) - """ setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) @@ -576,12 +581,20 @@ Return the distribution from which `vn` was sampled in `vi`. getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] # HACK: we shouldn't need this -getdist(::VarNameVector, ::VarName) = nothing +function getdist(::VarNamedVector, ::VarName) + throw(ErrorException("getdist does not exist for VarNamedVector")) +end getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn) -getindex_internal(md::Metadata, vn::VarName) = view(md.vals, getrange(md, vn)) +# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though, +# since then we might be returning a `SubArray` rather than an `Array`, which is typically +# what a bijector would result in, even if the input is a view (`SubArray`). +# TODO(torfjelde): An alternative is to implement `view` directly instead. +getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) # HACK: We shouldn't need this -getindex_internal(vnv::VarNameVector, vn::VarName) = view(vnv.vals, getrange(vnv, vn)) +# TODO(mhauru) This seems to return an array always for VarNamedVector, but a scalar for +# Metadata. What's the right thing to do here? +getindex_internal(vnv::VarNamedVector, vn::VarName) = getindex(vnv.vals, getrange(vnv, vn)) function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) @@ -601,9 +614,6 @@ end function setval!(md::Metadata, val, vn::VarName) return md.vals[getrange(md, vn)] = tovec(val) end -function setval!(vnv::VarNameVector, val, vn::VarName) - return setindex_raw!(vnv, tovec(val), vn) -end """ getall(vi::VarInfo) @@ -621,7 +631,7 @@ function getall(md::Metadata) Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) ) end -getall(vnv::VarNameVector) = vnv.vals +getall(vnv::VarNamedVector) = vnv.vals """ setall!(vi::VarInfo, val) @@ -637,7 +647,7 @@ function _setall!(metadata::Metadata, val) metadata.vals[r] .= val[r] end end -function _setall!(vnv::VarNameVector, val) +function _setall!(vnv::VarNamedVector, val) # TODO: Do something more efficient here. for i in 1:length(vnv) vnv[i] = val[i] @@ -675,7 +685,10 @@ function settrans!!(metadata::Metadata, trans::Bool, vn::VarName) return metadata end -function settrans!!(vnv::VarNameVector, trans::Bool, vn::VarName) + +# TODO(mhauru) Isn't this infinite recursion? Shouldn't rather change the `transforms` +# field? +function settrans!!(vnv::VarNamedVector, trans::Bool, vn::VarName) settrans!(vnv, trans, vn) return vnv end @@ -759,7 +772,7 @@ end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) end -@inline function findinds(f_meta, s, ::Val{space}) where {space} +@inline function findinds(f_meta::Metadata, s, ::Val{space}) where {space} # Get all the idcs of the vns in `space` and that belong to the selector `s` return filter( (i) -> @@ -768,11 +781,27 @@ end 1:length(f_meta.gids), ) end -@inline function findinds(f_meta) +@inline function findinds(f_meta::Metadata) # Get all the idcs of the vns return filter((i) -> isempty(f_meta.gids[i]), 1:length(f_meta.gids)) end +function findinds(vnv::VarNamedVector, ::Selector, ::Val{space}) where {space} + # New Metadata objects are created with an empty list of gids, which is intrepreted as + # all Selectors applying to all variables. We assume the same behavior for + # VarNamedVector, and thus ignore the Selector argument. + if space !== () + msg = "VarNamedVector does not support selecting variables based on samplers" + throw(ErrorException(msg)) + else + return findinds(vnv) + end +end + +function findinds(vnv::VarNamedVector) + return 1:length(vnv.varnames) +end + # Get all vns of variables belonging to spl _getvns(vi::VarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) function _getvns(vi::VarInfo, spl::Union{SampleFromPrior,SampleFromUniform}) @@ -788,7 +817,7 @@ end @generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} exprs = [] for f in names - push!(exprs, :($f = metadata.$f.vns[idcs.$f])) + push!(exprs, :($f = Base.keys(metadata.$f)[idcs.$f])) end length(exprs) == 0 && return :(NamedTuple()) return :($(exprs...),) @@ -848,13 +877,30 @@ function set_flag!(md::Metadata, vn::VarName, flag::String) return md.flags[flag][getidx(md, vn)] = true end +function set_flag!(vnv::VarNamedVector, ::VarName, flag::String) + if flag == "del" + # The "del" flag is effectively always set for a VarNamedVector, so this is a no-op. + else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end + return vnv +end + #### #### APIs for typed and untyped VarInfo #### # VarInfo -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) +VarInfo(meta=VarNamedVector()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) + +function TypedVarInfo(vi::VectorVarInfo) + new_metas = group_by_symbol(vi.metadata) + logp = getlogp(vi) + num_produce = get_num_produce(vi) + nt = NamedTuple(new_metas) + return VarInfo(nt, Ref(logp), Ref(num_produce)) +end """ TypedVarInfo(vi::UntypedVarInfo) @@ -966,8 +1012,14 @@ end Add `gid` to the set of sampler selectors associated with `vn` in `vi`. """ -function setgid!(vi::VarInfo, gid::Selector, vn::VarName) - return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) +setgid!(vi::VarInfo, gid::Selector, vn::VarName) = setgid!(getmetadata(vi, vn), gid, vn) + +function setgid!(m::Metadata, gid::Selector, vn::VarName) + return push!(m.gids[getidx(m, vn)], gid) +end + +function setgid!(vnv::VarNamedVector, gid::Selector, vn::VarName) + throw(ErrorException("Calling setgid! on a VarNamedVector isn't valid.")) end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) @@ -1014,20 +1066,18 @@ and parameters sampled in `vi` to 0. """ reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) -isempty(vi::UntypedVarInfo) = isempty(vi.metadata.idcs) -isempty(vi::TypedVarInfo) = _isempty(vi.metadata) +# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). +isempty(vi::VarInfo) = _isempty(vi.metadata) +_isempty(metadata::Metadata) = isempty(metadata.idcs) +_isempty(vnv::VarNamedVector) = isempty(vnv) @generated function _isempty(metadata::NamedTuple{names}) where {names} - expr = Expr(:&&, :true) - for f in names - push!(expr.args, :(isempty(metadata.$f.idcs))) - end - return expr + return Expr(:&&, (:(isempty(metadata.$f)) for f in names)...) end # X -> R for all variables associated with given sampler function link!!(t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) - # If we're working with a `VarNameVector`, we always use immutable. - has_varnamevector(vi) && return link(t, vi, spl, model) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return link(t, vi, spl, model) # Call `_link!` instead of `link!` to avoid deprecation warning. _link!(vi, spl) return vi @@ -1070,10 +1120,8 @@ function _link!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns - dist = getdist(vi, vn) - _inner_transform!( - vi, vn, dist, internal_to_linked_internal_transform(vi, vn, dist) - ) + f = internal_to_linked_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, true, vn) end else @@ -1100,13 +1148,8 @@ end if ~istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns - dist = getdist(vi, vn) - _inner_transform!( - vi, - vn, - dist, - internal_to_linked_internal_transform(vi, vn, dist), - ) + f = internal_to_linked_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, true, vn) end else @@ -1123,8 +1166,8 @@ end function invlink!!( t::DynamicTransformation, vi::VarInfo, spl::AbstractSampler, model::Model ) - # If we're working with a `VarNameVector`, we always use immutable. - has_varnamevector(vi) && return invlink(t, vi, spl, model) + # If we're working with a `VarNamedVector`, we always use immutable. + has_varnamedvector(vi) && return invlink(t, vi, spl, model) # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. _invlink!(vi, spl) return vi @@ -1176,10 +1219,8 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns - dist = getdist(vi, vn) - _inner_transform!( - vi, vn, dist, linked_internal_to_internal_transform(vi, vn, dist) - ) + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, false, vn) end else @@ -1206,13 +1247,8 @@ end if istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns - dist = getdist(vi, vn) - _inner_transform!( - vi, - vn, - dist, - linked_internal_to_internal_transform(vi, vn, dist), - ) + f = linked_internal_to_internal_transform(vi, vn) + _inner_transform!(vi, vn, f) settrans!!(vi, false, vn) end else @@ -1225,11 +1261,11 @@ end return expr end -function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) - return _inner_transform!(getmetadata(vi, vn), vi, vn, dist, f) +function _inner_transform!(vi::VarInfo, vn::VarName, f) + return _inner_transform!(getmetadata(vi, vn), vi, vn, f) end -function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, dist, f) +function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) # TODO: Use inplace versions to avoid allocations yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn)) # Determine the new range. @@ -1267,7 +1303,9 @@ function link( return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, spl, model) end -function _link(model::Model, varinfo::UntypedVarInfo, spl::AbstractSampler) +function _link( + model::Model, varinfo::Union{UntypedVarInfo,VectorVarInfo}, spl::AbstractSampler +) varinfo = deepcopy(varinfo) return VarInfo( _link_metadata!(model, varinfo, varinfo.metadata, _getvns_link(varinfo, spl)), @@ -1322,7 +1360,7 @@ function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, tar yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. acclogp!!(varinfo, -logjac) - # Mark as no longer transformed. + # Mark as transformed. settrans!!(varinfo, true, vn) # Return the vectorized transformed value. return yvec @@ -1351,14 +1389,14 @@ function _link_metadata!(model::Model, varinfo::VarInfo, metadata::Metadata, tar end function _link_metadata!( - model::Model, varinfo::VarInfo, metadata::VarNameVector, target_vns + model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) # HACK: We ignore `target_vns` here. vns = keys(metadata) # Need to extract the priors from the model. dists = extract_priors(model, varinfo) - is_transformed = copy(metadata.is_transformed) + is_unconstrained = copy(metadata.is_unconstrained) # Construct the linking transformations. link_transforms = map(vns) do vn @@ -1368,13 +1406,12 @@ function _link_metadata!( end # Otherwise, we derive the transformation from the distribution. - is_transformed[getidx(metadata, vn)] = true + # TODO(mhauru) Could move the mutation outside of the map, just for style. + is_unconstrained[getidx(metadata, vn)] = true internal_to_linked_internal_transform(varinfo, vn, dists[vn]) end # Compute the transformed values. ys = map(vns, link_transforms) do vn, f - # TODO: Do we need to handle scenarios where `vn` is not in `dists`? - dist = dists[vn] x = getindex_internal(metadata, vn) y, logjac = with_logabsdet_jacobian(f, x) # Accumulate the log-abs-det jacobian correction. @@ -1400,13 +1437,13 @@ function _link_metadata!( end # Now we just create a new metadata with the new `vals` and `ranges`. - return VarNameVector( + return VarNamedVector( metadata.varname_to_index, metadata.varnames, ranges_new, reduce(vcat, yvecs), transforms, - is_transformed, + is_unconstrained, ) end @@ -1467,7 +1504,7 @@ function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, targe # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. + # supposed to touch this `vn`, e.g. when `vn` does not belong to the current sampler. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] @@ -1511,14 +1548,14 @@ function _invlink_metadata!(::Model, varinfo::VarInfo, metadata::Metadata, targe end function _invlink_metadata!( - model::Model, varinfo::VarInfo, metadata::VarNameVector, target_vns + model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) # HACK: We ignore `target_vns` here. # TODO: Make use of `update!` to aovid copying values. # => Only need to allocate for transformations. vns = keys(metadata) - is_transformed = copy(metadata.is_transformed) + is_unconstrained = copy(metadata.is_unconstrained) # Compute the transformed values. xs = map(vns) do vn @@ -1529,7 +1566,7 @@ function _invlink_metadata!( # Accumulate the log-abs-det jacobian correction. acclogp!!(varinfo, -logjac) # Mark as no longer transformed. - is_transformed[getidx(metadata, vn)] = false + is_unconstrained[getidx(metadata, vn)] = false # Return the transformed value. return x end @@ -1549,13 +1586,13 @@ function _invlink_metadata!( end # Now we just create a new metadata with the new `vals` and `ranges`. - return VarNameVector( + return VarNamedVector( metadata.varname_to_index, metadata.varnames, ranges_new, reduce(vcat, xvecs), transforms, - is_transformed, + is_unconstrained, ) end @@ -1564,9 +1601,9 @@ end Check whether `vi` is in the transformed space for a particular sampler `spl`. -Turing's Hamiltonian samplers use the `link` and `invlink` functions from +Turing's Hamiltonian samplers use the `link` and `invlink` functions from [Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable -(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of +(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of real numbers. `islinked` checks if the number is in the constrained space or the real space. """ function islinked(vi::UntypedVarInfo, spl::Union{Sampler,SampleFromPrior}) @@ -1597,9 +1634,11 @@ function nested_setindex_maybe!( nothing end end -function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) +function _nested_setindex_maybe!( + vi::VarInfo, md::Union{Metadata,VarNamedVector}, val, vn::VarName +) # If `vn` is in `vns`, then we can just use the standard `setindex!`. - vns = md.vns + vns = Base.keys(md) if vn in vns setindex!(vi, val, vn) return vn @@ -1610,8 +1649,7 @@ function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) i === nothing && return nothing vn_parent = vns[i] - dist = getdist(md, vn_parent) - val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here. + val_parent = getindex(vi, vn_parent) # TODO: Ensure that we're working with a view here. # Split the varname into its tail optic. optic = remove_parent_optic(vn_parent, vn) # Update the value for the parent. @@ -1622,30 +1660,47 @@ end # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type -getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) +function getindex(vi::VarInfo, vn::VarName) + return from_maybe_linked_internal_transform(vi, vn)(getindex_internal(vi, vn)) +end + function getindex(vi::VarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" val = getindex_internal(vi, vn) return from_maybe_linked_internal(vi, vn, dist, val) end -# HACK: Allows us to also work with `VarNameVector` where `dist` is not used, -# but we instead use a transformation stored with the variable. -function getindex(vi::VarInfo, vn::VarName, ::Nothing) - if !haskey(vi, vn) - throw(KeyError(vn)) - end - return getmetadata(vi, vn)[vn] -end function getindex(vi::VarInfo, vns::Vector{<:VarName}) - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn) + vals = map(vn -> getindex(vi, vn), vns) + + et = eltype(vals) + # This will catch type unstable cases, where vals has mixed types. + if !isconcretetype(et) + throw(ArgumentError("All variables must have the same type.")) + end + + if et <: Vector + all_of_equal_dimension = all(x -> length(x) == length(vals[1]), vals) + if !all_of_equal_dimension + throw(ArgumentError("All variables must have the same dimension.")) + end + end + + # TODO(mhauru) I'm not very pleased with the return type varying like this, even though + # this should be type stable. + vec_vals = reduce(vcat, vals) + if et <: Vector + # The individual variables are multivariate, and thus we return the values as a + # matrix. + return reshape(vec_vals, (:, length(vns))) + else + # The individual variables are univariate, and thus we return a vector of scalars. + return vec_vals end - # HACK: I don't like this. - dist = getdist(vi, vns[1]) - return recombine(dist, vals_linked, length(vns)) end + function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) + # TODO(mhauru) Does this ever get called? @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" vals_linked = mapreduce(vcat, vns) do vn getindex(vi, vn, dist) @@ -1828,11 +1883,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) return meta end -function Base.push!(vnv::VarNameVector, vn, r, dist, gidset, num_produce) - f = from_vec_transform(dist) - return push!(vnv, vn, r, f) -end - """ setorder!(vi::VarInfo, vn::VarName, index::Int) @@ -1847,7 +1897,7 @@ function setorder!(metadata::Metadata, vn::VarName, index::Int) metadata.orders[metadata.idcs[vn]] = index return metadata end -setorder!(vnv::VarNameVector, ::VarName, ::Int) = vnv +setorder!(vnv::VarNamedVector, ::VarName, ::Int) = vnv """ getorder(vi::VarInfo, vn::VarName) @@ -1873,23 +1923,44 @@ end function is_flagged(metadata::Metadata, vn::VarName, flag::String) return metadata.flags[flag][getidx(metadata, vn)] end -# HACK: This is bad. Should we always return `true` here? -is_flagged(::VarNameVector, ::VarName, flag::String) = flag == "del" ? true : false +function is_flagged(::VarNamedVector, ::VarName, flag::String) + if flag == "del" + return true + else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end +end +# TODO(mhauru) The "ignorable" argument is a temporary hack while developing VarNamedVector, +# but still having to support the interface based on Metadata too """ - unset_flag!(vi::VarInfo, vn::VarName, flag::String) + unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false Set `vn`'s value for `flag` to `false` in `vi`. + +If `ignorable` is `false`, as it is by default, then this will error if setting the flag is +not possible. """ -function unset_flag!(vi::VarInfo, vn::VarName, flag::String) - unset_flag!(getmetadata(vi, vn), vn, flag) +function unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false) + unset_flag!(getmetadata(vi, vn), vn, flag, ignorable) return vi end -function unset_flag!(metadata::Metadata, vn::VarName, flag::String) +function unset_flag!(metadata::Metadata, vn::VarName, flag::String, ignorable::Bool=false) metadata.flags[flag][getidx(metadata, vn)] = false return metadata end -unset_flag!(vnv::VarNameVector, ::VarName, ::String) = vnv + +function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bool=false) + if ignorable + return vnv + end + if flag == "del" + throw(ErrorException("The \"del\" flag cannot be unset for VarNamedVector")) + else + throw(ErrorException("Flag $flag not valid for VarNamedVector")) + end + return vnv +end """ set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) @@ -1995,7 +2066,7 @@ end ) where {names} updates = map(names) do n quote - for vn in metadata.$n.vns + for vn in Base.keys(metadata.$n) indices_found = kernel!(vi, vn, values, keys_strings) if indices_found !== nothing num_indices_seen += length(indices_found) @@ -2077,14 +2148,6 @@ julia> DynamicPPL.setval!(var_info, (m = 100.0, )); # set `m` and and keep `x[1] julia> var_info[@varname(m)] # [✓] changed 100.0 -julia> var_info[@varname(x[1])] # [✓] unchanged --0.22312984965118443 - -julia> m(rng, var_info); # rerun model - -julia> var_info[@varname(m)] # [✓] unchanged -100.0 - julia> var_info[@varname(x[1])] # [✓] unchanged -0.22312984965118443 ``` @@ -2136,7 +2199,7 @@ julia> rng = StableRNG(42); julia> m = demo([missing]); -julia> var_info = DynamicPPL.VarInfo(rng, m); +julia> var_info = DynamicPPL.VarInfo(rng, m, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata); julia> var_info[@varname(m)] -0.6702516921145671 @@ -2234,6 +2297,9 @@ function values_as( return ConstructionBase.constructorof(D)(iter) end +values_as(vi::VectorVarInfo, args...) = values_as(vi.metadata, args...) +values_as(vi::VectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) + function values_from_metadata(md::Metadata) return ( # `copy` to avoid accidentally mutation of internal representation. @@ -2243,7 +2309,7 @@ function values_from_metadata(md::Metadata) ) end -values_from_metadata(md::VarNameVector) = pairs(md) +values_from_metadata(md::VarNamedVector) = pairs(md) # Transforming from internal representation to distribution representation. # Without `dist` argument: base on `dist` extracted from self. @@ -2253,7 +2319,7 @@ end function from_internal_transform(md::Metadata, vn::VarName) return from_internal_transform(md, vn, getdist(md, vn)) end -function from_internal_transform(md::VarNameVector, vn::VarName) +function from_internal_transform(md::VarNamedVector, vn::VarName) return gettransform(md, vn) end # With both `vn` and `dist` arguments: base on provided `dist`. @@ -2261,7 +2327,7 @@ function from_internal_transform(vi::VarInfo, vn::VarName, dist) return from_internal_transform(getmetadata(vi, vn), vn, dist) end from_internal_transform(::Metadata, ::VarName, dist) = from_vec_transform(dist) -function from_internal_transform(::VarNameVector, ::VarName, dist) +function from_internal_transform(::VarNamedVector, ::VarName, dist) return from_vec_transform(dist) end @@ -2272,7 +2338,7 @@ end function from_linked_internal_transform(md::Metadata, vn::VarName) return from_linked_internal_transform(md, vn, getdist(md, vn)) end -function from_linked_internal_transform(md::VarNameVector, vn::VarName) +function from_linked_internal_transform(md::VarNamedVector, vn::VarName) return gettransform(md, vn) end # With both `vn` and `dist` arguments: base on provided `dist`. @@ -2283,6 +2349,6 @@ end function from_linked_internal_transform(::Metadata, ::VarName, dist) return from_linked_vec_transform(dist) end -function from_linked_internal_transform(::VarNameVector, ::VarName, dist) +function from_linked_internal_transform(::VarNamedVector, ::VarName, dist) return from_linked_vec_transform(dist) end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl new file mode 100644 index 000000000..e9a7126ae --- /dev/null +++ b/src/varnamedvector.jl @@ -0,0 +1,1217 @@ +""" + VarNamedVector + +A container that stores values in a vectorised form, but indexable by variable names. + +When indexed by integers or `Colon`s, e.g. `vnv[2]` or `vnv[:]`, `VarNamedVector` behaves +like a `Vector`, and returns the values as they are stored. The stored form is always +vectorised, for instance matrix variables have been flattened, and may be further +transformed to achieve linking. + +When indexed by `VarName`s, e.g. `vnv[@varname(x)]`, `VarNamedVector` returns the values +in the original space. For instance, a linked matrix variable is first inverse linked and +then reshaped to its original form before returning it to the caller. + +`VarNamedVector` also stores a boolean for whether a variable has been transformed to +unconstrained Euclidean space or not. + +# Fields +$(FIELDS) +""" +struct VarNamedVector{ + K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector +} + """ + mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` + """ + varname_to_index::OrderedDict{K,Int} + + """ + vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn` + """ + varnames::TVN # AbstractVector{<:VarName} + + """ + vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has + a single index or a set of contiguous indices, such that the values of `vn` can be found + at `vals[ranges[varname_to_index[vn]]]` + """ + ranges::Vector{UnitRange{Int}} + + """ + vector of values of all variables; the value(s) of `vn` is/are + `vals[ranges[varname_to_index[vn]]]` + """ + vals::TVal # AbstractVector{<:Real} + + """ + vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable + that transformes the value of `vn` back to its original space, undoing any linking and + vectorisation + """ + transforms::TTrans + + """ + vector of booleans indicating whether a variable has been transformed to unconstrained + Euclidean space or not, i.e. whether its domain is all of `ℝ^ⁿ`. Having + `is_unconstrained[varname_to_index[vn]] == false` does not necessarily mean that a + variable is constrained, but rather that it's not guaranteed to not be. + """ + is_unconstrained::BitVector + + """ + mapping from a variable index to the number of inactive entries for that variable. + Inactive entries are elements in `vals` that are not part of the value of any variable. + They arise when transformations change the dimension of the value stored. In active + entries always come after the last active entry for the given variable. + """ + num_inactive::OrderedDict{Int,Int} + + function VarNamedVector( + varname_to_index, + varnames::TVN, + ranges, + vals::TVal, + transforms::TTrans, + is_unconstrained, + num_inactive, + ) where {K,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector} + if length(varnames) != length(ranges) || + length(varnames) != length(transforms) || + length(varnames) != length(is_unconstrained) || + length(varnames) != length(varname_to_index) + msg = """ + Inputs to VarNamedVector have inconsistent lengths. Got lengths \ + varnames: $(length(varnames)), \ + ranges: $(length(ranges)), \ + transforms: $(length(transforms)), \ + is_unconstrained: $(length(is_unconstrained)), \ + varname_to_index: $(length(varname_to_index)).""" + throw(ArgumentError(msg)) + end + + num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive)) + if num_vals != length(vals) + msg = """ + The total number of elements in `vals` does not match the sum of the \ + lengths of the ranges and the number of inactive entries.""" + throw(ArgumentError(msg)) + end + + if Set(values(varname_to_index)) != Set(1:length(varnames)) + msg = "The values of `varname_to_index` are not valid indices." + throw(ArgumentError(msg)) + end + + if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index))) + msg = "The keys of `num_inactive` are not valid indices." + throw(ArgumentError(msg)) + end + + # Check that the varnames don't overlap. The time cost is quadratic in number of + # variables. If this ever becomes an issue, we should be able to go down to at least + # N log N by sorting based on subsumes-order. + for vn1 in keys(varname_to_index) + for vn2 in keys(varname_to_index) + vn1 === vn2 && continue + if subsumes(vn1, vn2) + msg = """ + Variables in a VarNamedVector should not subsume each other, \ + but $vn1 subsumes $vn2""" + throw(ArgumentError(msg)) + end + end + end + + # We could also have a test to check that the ranges don't overlap, but that sounds + # unlikely to occur, and implementing it in linear time would require a tiny bit of + # thought. + + return new{K,V,TVN,TVal,TTrans}( + varname_to_index, + varnames, + ranges, + vals, + transforms, + is_unconstrained, + num_inactive, + ) + end +end + +# Default values for is_unconstrained (all false) and num_inactive (empty). +function VarNamedVector( + varname_to_index, + varnames, + ranges, + vals, + transforms, + is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), +) + return VarNamedVector( + varname_to_index, + varnames, + ranges, + vals, + transforms, + is_unconstrained, + OrderedDict{Int,Int}(), + ) +end + +# TODO(mhauru) Are we sure we want the last one to be of type Any[]? Might this call +# unnecessary type instability? +function VarNamedVector{K,V}() where {K,V} + return VarNamedVector(OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]) +end + +# TODO(mhauru) I would like for this to be VarNamedVector(Union{}, Union{}). This would +# allow expanding the VarName and element types only as necessary, which would help keep +# them concrete. However, making that change here opens some other cans of worms related to +# how VarInfo uses BangBang, that I don't want to deal with right now. +VarNamedVector() = VarNamedVector{VarName,Real}() +VarNamedVector(xs::Pair...) = VarNamedVector(OrderedDict(xs...)) +VarNamedVector(x::AbstractDict) = VarNamedVector(keys(x), values(x)) +function VarNamedVector(varnames, vals) + return VarNamedVector(collect_maybe(varnames), collect_maybe(vals)) +end +function VarNamedVector( + varnames::AbstractVector, vals::AbstractVector, transforms=map(from_vec_transform, vals) +) + # Convert `vals` into a vector of vectors. + vals_vecs = map(tovec, vals) + + # TODO: Is this really the way to do this? + if !(eltype(varnames) <: VarName) + varnames = convert(Vector{VarName}, varnames) + end + varname_to_index = OrderedDict{eltype(varnames),Int}( + vn => i for (i, vn) in enumerate(varnames) + ) + vals = reduce(vcat, vals_vecs) + # Make the ranges. + ranges = Vector{UnitRange{Int}}() + offset = 0 + for x in vals_vecs + r = (offset + 1):(offset + length(x)) + push!(ranges, r) + offset = r[end] + end + + return VarNamedVector(varname_to_index, varnames, ranges, vals, transforms) +end + +function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) + return vnv_left.varname_to_index == vnv_right.varname_to_index && + vnv_left.varnames == vnv_right.varnames && + vnv_left.ranges == vnv_right.ranges && + vnv_left.vals == vnv_right.vals && + vnv_left.transforms == vnv_right.transforms && + vnv_left.is_unconstrained == vnv_right.is_unconstrained && + vnv_left.num_inactive == vnv_right.num_inactive +end + +# Some `VarNamedVector` specific functions. +getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] + +getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] +getrange(vnv::VarNamedVector, vn::VarName) = getrange(vnv, getidx(vnv, vn)) + +gettransform(vnv::VarNamedVector, idx::Int) = vnv.transforms[idx] +gettransform(vnv::VarNamedVector, vn::VarName) = gettransform(vnv, getidx(vnv, vn)) + +# TODO(mhauru) Eventually I would like to rename the istrans function to is_unconstrained, +# but that's significantly breaking. +""" + istrans(vnv::VarNamedVector, vn::VarName) + +Return a boolean for whether `vn` is guaranteed to have been transformed so that all of +Euclidean space is its domain. +""" +istrans(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] + +""" + settrans!(vnv::VarNamedVector, val::Bool, vn::VarName) + +Set the value for whether `vn` is guaranteed to have been transformed so that all of +Euclidean space is its domain. +""" +function settrans!(vnv::VarNamedVector, val::Bool, vn::VarName) + return vnv.is_unconstrained[vnv.varname_to_index[vn]] = val +end + +""" + has_inactive(vnv::VarNamedVector) + +Returns `true` if `vnv` has inactive ranges. +""" +has_inactive(vnv::VarNamedVector) = !isempty(vnv.num_inactive) + +""" + num_inactive(vnv::VarNamedVector) + +Return the number of inactive entries in `vnv`. +""" +num_inactive(vnv::VarNamedVector) = sum(values(vnv.num_inactive)) + +""" + num_inactive(vnv::VarNamedVector, vn::VarName) + +Returns the number of inactive entries for `vn` in `vnv`. +""" +num_inactive(vnv::VarNamedVector, vn::VarName) = num_inactive(vnv, getidx(vnv, vn)) +num_inactive(vnv::VarNamedVector, idx::Int) = get(vnv.num_inactive, idx, 0) + +""" + num_allocated(vnv::VarNamedVector) + +Returns the number of allocated entries in `vnv`, both active and inactive. +""" +num_allocated(vnv::VarNamedVector) = length(vnv.vals) + +""" + num_allocated(vnv::VarNamedVector, vn::VarName) + +Returns the number of allocated entries for `vn` in `vnv`, both active and inactive. +""" +num_allocated(vnv::VarNamedVector, vn::VarName) = num_allocated(vnv, getidx(vnv, vn)) +function num_allocated(vnv::VarNamedVector, idx::Int) + return length(getrange(vnv, idx)) + num_inactive(vnv, idx) +end + +# Basic array interface. +Base.eltype(vnv::VarNamedVector) = eltype(vnv.vals) +Base.length(vnv::VarNamedVector) = + if !has_inactive(vnv) + length(vnv.vals) + else + sum(length, vnv.ranges) + end +Base.size(vnv::VarNamedVector) = (length(vnv),) +Base.isempty(vnv::VarNamedVector) = isempty(vnv.varnames) + +# TODO: We should probably remove this +Base.IndexStyle(::Type{<:VarNamedVector}) = IndexLinear() + +# Dictionary interface. +Base.keys(vnv::VarNamedVector) = vnv.varnames +Base.values(vnv::VarNamedVector) = Iterators.map(Base.Fix1(getindex, vnv), vnv.varnames) +Base.pairs(vnv::VarNamedVector) = (vn => vnv[vn] for vn in keys(vnv)) + +Base.haskey(vnv::VarNamedVector, vn::VarName) = haskey(vnv.varname_to_index, vn) + +# `getindex` & `setindex!` +Base.getindex(vnv::VarNamedVector, i::Int) = getindex_raw(vnv, i) +function Base.getindex(vnv::VarNamedVector, vn::VarName) + x = getindex_raw(vnv, vn) + f = gettransform(vnv, vn) + return f(x) +end + +""" + find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) + +Find the first range in `ranges` that contains `x`. + +Throw an `ArgumentError` if `x` is not in any of the ranges. +""" +function find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) + # TODO: Assume `ranges` to be sorted and contiguous, and use `searchsortedfirst` + # for a more efficient approach. + range_idx = findfirst(Base.Fix1(∈, x), ranges) + + # If we're out of bounds, we raise an error. + if range_idx === nothing + throw(ArgumentError("Value $x is not in any of the ranges.")) + end + + return range_idx +end + +""" + adjusted_ranges(vnv::VarNamedVector) + +Return what `vnv.ranges` would be if there were no inactive entries. +""" +function adjusted_ranges(vnv::VarNamedVector) + # Every range following inactive entries needs to be shifted. + offset = 0 + ranges_adj = similar(vnv.ranges) + for (idx, r) in enumerate(vnv.ranges) + # Remove the `offset` in `r` due to inactive entries. + ranges_adj[idx] = r .- offset + # Update `offset`. + offset += get(vnv.num_inactive, idx, 0) + end + + return ranges_adj +end + +""" + index_to_vals_index(vnv::VarNamedVector, i::Int) + +Convert an integer index that ignores inactive entries to an index that accounts for them. + +This is needed when the user wants to index `vnv` like a vector, but shouldn't have to care +about inactive entries in `vnv.vals`. +""" +function index_to_vals_index(vnv::VarNamedVector, i::Int) + # If we don't have any inactive entries, there's nothing to do. + has_inactive(vnv) || return i + + # Get the adjusted ranges. + ranges_adj = adjusted_ranges(vnv) + # Determine the adjusted range that the index corresponds to. + r_idx = find_containing_range(ranges_adj, i) + r = vnv.ranges[r_idx] + # Determine how much of the index `i` is used to get to this range. + i_used = r_idx == 1 ? 0 : sum(length, ranges_adj[1:(r_idx - 1)]) + # Use remainder to index into `r`. + i_remainder = i - i_used + return r[i_remainder] +end + +""" + getindex_raw(vnv::VarNamedVector, i::Int) + getindex_raw(vnv::VarNamedVector, vn::VarName) + +Like `getindex`, but returns the values as they are stored in `vnv` without transforming. + +For integer indices this is the same as `getindex`, but for `VarName`s this is different. +""" +getindex_raw(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)] +getindex_raw(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)] + +# `getindex` for `Colon` +function Base.getindex(vnv::VarNamedVector, ::Colon) + return if has_inactive(vnv) + mapreduce(Base.Fix1(getindex, vnv.vals), vcat, vnv.ranges) + else + vnv.vals + end +end + +getindex_raw(vnv::VarNamedVector, ::Colon) = getindex(vnv, Colon()) + +# TODO(mhauru): Remove this as soon as possible. Only needed because of the old Gibbs +# sampler. +function Base.getindex(vnv::VarNamedVector, spl::AbstractSampler) + throw(ErrorException("Cannot index a VarNamedVector with a sampler.")) +end + +Base.setindex!(vnv::VarNamedVector, val, i::Int) = setindex_raw!(vnv, val, i) +function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) + f = inverse(gettransform(vnv, vn)) + return setindex_raw!(vnv, f(val), vn) +end + +""" + setindex_raw!(vnv::VarNamedVector, val, i::Int) + setindex_raw!(vnv::VarNamedVector, val, vn::VarName) + +Like `setindex!`, but sets the values as they are stored in `vnv` without transforming. + +For integer indices this is the same as `setindex!`, but for `VarName`s this is different. +""" +function setindex_raw!(vnv::VarNamedVector, val, i::Int) + return vnv.vals[index_to_vals_index(vnv, i)] = val +end + +function setindex_raw!(vnv::VarNamedVector, val::AbstractVector, vn::VarName) + return vnv.vals[getrange(vnv, vn)] = val +end + +function Base.empty!(vnv::VarNamedVector) + # TODO: Or should the semantics be different, e.g. keeping `varnames`? + empty!(vnv.varname_to_index) + empty!(vnv.varnames) + empty!(vnv.ranges) + empty!(vnv.vals) + empty!(vnv.transforms) + empty!(vnv.is_unconstrained) + empty!(vnv.num_inactive) + return nothing +end +BangBang.empty!!(vnv::VarNamedVector) = (empty!(vnv); return vnv) + +""" + replace_values(vnv::VarNamedVector, vals::AbstractVector) + +Replace the values in `vnv` with `vals`, as they are stored internally. + +This is useful when we want to update the entire underlying vector of values in one go or if +we want to change the how the values are stored, e.g. alter the `eltype`. + +!!! warning + This replaces the raw underlying values, and so care should be taken when using this + function. For example, if `vnv` has any inactive entries, then the provided `vals` + should also contain the inactive entries to avoid unexpected behavior. + +# Examples + +```jldoctest varnamedvector-replace-values +julia> using DynamicPPL: VarNamedVector, replace_values + +julia> vnv = VarNamedVector(@varname(x) => [1.0]); + +julia> replace_values(vnv, [2.0])[@varname(x)] == [2.0] +true +``` + +This is also useful when we want to differentiate wrt. the values using automatic +differentiation, e.g. ForwardDiff.jl. + +```jldoctest varnamedvector-replace-values +julia> using ForwardDiff: ForwardDiff + +julia> f(x) = sum(abs2, replace_values(vnv, x)[@varname(x)]) +f (generic function with 1 method) + +julia> ForwardDiff.gradient(f, [1.0]) +1-element Vector{Float64}: + 2.0 +``` +""" +replace_values(vnv::VarNamedVector, vals) = Accessors.@set vnv.vals = vals + +# TODO(mhauru) The space argument is used by the old Gibbs sampler. To be removed. +function replace_values(vnv::VarNamedVector, ::Val{space}, vals) where {space} + if length(space) > 0 + msg = "Selecting values in a VarNamedVector with a space is not supported." + throw(ArgumentError(msg)) + end + return replace_values(vnv, vals) +end + +""" + unflatten(vnv::VarNamedVector, vals::AbstractVector) + +Return a new instance of `vnv` with the values of `vals` assigned to the variables. + +This assumes that `vals` have been transformed by the same transformations that that the +values in `vnv` have been transformed by. However, unlike [`replace_values`](@ref), +`unflatten` does account for inactive entries in `vnv`, so that the user does not have to +care about them. + +This is in a sense the reverse operation of `vnv[:]`. + +Unflatten recontiguifies the internal storage, getting rid of any inactive entries. + +# Examples + +```jldoctest varnamedvector-unflatten +julia> using DynamicPPL: VarNamedVector, unflatten + +julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]); + +julia> unflatten(vnv, vnv[:]) == vnv +true +""" +function unflatten(vnv::VarNamedVector, vals::AbstractVector) + new_ranges = deepcopy(vnv.ranges) + recontiguify_ranges!(new_ranges) + return VarNamedVector( + vnv.varname_to_index, vnv.varnames, new_ranges, vals, vnv.transforms + ) +end + +function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) + # Return early if possible. + isempty(left_vnv) && return deepcopy(right_vnv) + isempty(right_vnv) && return deepcopy(left_vnv) + + # Determine varnames. + vns_left = left_vnv.varnames + vns_right = right_vnv.varnames + vns_both = union(vns_left, vns_right) + + # Determine `eltype` of `vals`. + T_left = eltype(left_vnv.vals) + T_right = eltype(right_vnv.vals) + T = promote_type(T_left, T_right) + + # Determine `eltype` of `varnames`. + V_left = eltype(left_vnv.varnames) + V_right = eltype(right_vnv.varnames) + V = promote_type(V_left, V_right) + if !(V <: VarName) + V = VarName + end + + # Determine `eltype` of `transforms`. + F_left = eltype(left_vnv.transforms) + F_right = eltype(right_vnv.transforms) + F = promote_type(F_left, F_right) + + # Allocate. + varname_to_index = OrderedDict{V,Int}() + ranges = UnitRange{Int}[] + vals = T[] + transforms = F[] + is_unconstrained = BitVector(undef, length(vns_both)) + + # Range offset. + offset = 0 + + for (idx, vn) in enumerate(vns_both) + varname_to_index[vn] = idx + # Extract the necessary information from `left` or `right`. + if vn in vns_left && !(vn in vns_right) + # `vn` is only in `left`. + val = getindex_raw(left_vnv, vn) + f = gettransform(left_vnv, vn) + is_unconstrained[idx] = istrans(left_vnv, vn) + else + # `vn` is either in both or just `right`. + # Note that in a `merge` the right value has precedence. + val = getindex_raw(right_vnv, vn) + f = gettransform(right_vnv, vn) + is_unconstrained[idx] = istrans(right_vnv, vn) + end + n = length(val) + r = (offset + 1):(offset + n) + # Update. + append!(vals, val) + push!(ranges, r) + push!(transforms, f) + # Increment `offset`. + offset += n + end + + return VarNamedVector( + varname_to_index, vns_both, ranges, vals, transforms, is_unconstrained + ) +end + +""" + subset(vnv::VarNamedVector, vns::AbstractVector{<:VarName}) + +Return a new `VarNamedVector` containing the values from `vnv` for variables in `vns`. + +Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning +that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`. + +# Examples + +```jldoctest varnamedvector-subset +julia> using DynamicPPL: VarNamedVector, @varname, subset + +julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]); + +julia> subset(vnv, [@varname(x)]) == VarNamedVector(@varname(x) => [1.0, 2.0]) +true + +julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) +true +""" +function subset(vnv::VarNamedVector, vns_given::AbstractVector{VN}) where {VN<:VarName} + # NOTE: This does not specialize types when possible. + vns = mapreduce(vcat, vns_given; init=VN[]) do vn + filter(Base.Fix1(subsumes, vn), vnv.varnames) + end + vnv_new = similar(vnv) + # Return early if possible. + isempty(vnv) && return vnv_new + + for vn in vns + push!(vnv_new, vn, getindex_raw(vnv, vn), gettransform(vnv, vn)) + settrans!(vnv_new, istrans(vnv, vn), vn) + end + + return vnv_new +end + +""" + similar(vnv::VarNamedVector) + +Return a new `VarNamedVector` with the same structure as `vnv`, but with empty values. + +In this respect `vnv` behaves more like a dictionary than an array: `similar(vnv)` will +be entirely empty, rather than have `undef` values in it. + +# Examples + +```julia-doctest-varnamedvector-similar +julia> using DynamicPPL: VarNamedVector, @varname, similar + +julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(x[3]) => [3.0]); + +julia> similar(vnv) == VarNamedVector{VarName{:x}, Float64}() +true +""" +function Base.similar(vnv::VarNamedVector) + # NOTE: Whether or not we should empty the underlying containers or not + # is somewhat ambiguous. For example, `similar(vnv.varname_to_index)` will + # result in an empty `AbstractDict`, while the vectors, e.g. `vnv.ranges`, + # will result in non-empty vectors but with entries as `undef`. But it's + # much easier to write the rest of the code assuming that `undef` is not + # present, and so for now we empty the underlying containers, thus differing + # from the behavior of `similar` for `AbstractArray`s. + return VarNamedVector( + similar(vnv.varname_to_index), + similar(vnv.varnames, 0), + similar(vnv.ranges, 0), + similar(vnv.vals, 0), + similar(vnv.transforms, 0), + BitVector(), + similar(vnv.num_inactive), + ) +end + +""" + is_contiguous(vnv::VarNamedVector) + +Returns `true` if the underlying data of `vnv` is stored in a contiguous array. + +This is equivalent to negating [`has_inactive(vnv)`](@ref). +""" +is_contiguous(vnv::VarNamedVector) = !has_inactive(vnv) + +""" + nextrange(vnv::VarNamedVector, x) + +Return the range of `length(x)` from the end of current data in `vnv`. +""" +function nextrange(vnv::VarNamedVector, x) + offset = length(vnv.vals) + return (offset + 1):(offset + length(x)) +end + +""" + push!(vnv::VarNamedVector, vn::VarName, val[, transform]) + +Add a variable with given value to `vnv`. + +By default `transform` is the one that converts the value to a vector, which is how it is +stored in `vnv`. +""" +function Base.push!( + vnv::VarNamedVector, vn::VarName, val, transform=from_vec_transform(val) +) + # Error if we already have the variable. + haskey(vnv, vn) && throw(ArgumentError("variable name $vn already exists")) + # NOTE: We need to compute the `nextrange` BEFORE we start mutating the underlying + # storage. + val_vec = tovec(val) + r_new = nextrange(vnv, val_vec) + vnv.varname_to_index[vn] = length(vnv.varname_to_index) + 1 + push!(vnv.varnames, vn) + push!(vnv.ranges, r_new) + append!(vnv.vals, val_vec) + push!(vnv.transforms, transform) + push!(vnv.is_unconstrained, false) + return nothing +end + +# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler. +# Remove this method as soon as possible. +function Base.push!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce) + f = from_vec_transform(dist) + return push!(vnv, vn, val, f) +end + +""" + loosen_types!!(vnv::VarNamedVector{K,V,TVN,TVal,TTrans}, ::Type{KNew}, ::Type{TransNew}) + +Loosen the types of `vnv` to allow varname type `KNew` and transformation type `TransNew`. + +If `KNew` is a subtype of `K` and `TransNew` is a subtype of the element type of the +`TTrans` then this is a no-op and `vnv` is returned as is. Otherwise a new `VarNamedVector` +is returned with the same data but more abstract types, so that variables of type `KNew` and +transformations of type `TransNew` can be pushed to it. Some of the underlying storage is +shared between `vnv` and the return value, and thus mutating one may affect the other. + +# See also +[`tighten_types`](@ref) + +# Examples + +```jldoctest varnamedvector-loosen-types +julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!! + +julia> vnv = VarNamedVector(@varname(x) => [1.0]); + +julia> vnv_new = loosen_types!!(vnv, VarName{:x}, Real); + +julia> push!(vnv, @varname(y), Float32[2.0]) +ERROR: MethodError: Cannot `convert` an object of type + VarName{y,typeof(identity)} to an object of type + VarName{x,typeof(identity)} +[...] + +julia> vnv_loose = DynamicPPL.loosen_types!!(vnv, typeof(@varname(y)), Float32); + +julia> push!(vnv_loose, @varname(y), Float32[2.0]); vnv_loose # Passes without issues. +VarNamedVector{VarName{sym, typeof(identity)} where sym, Float64, Vector{VarName{sym, typeof(identity)} where sym}, Vector{Float64}, Vector{Any}}(OrderedDict{VarName{sym, typeof(identity)} where sym, Int64}(x => 1, y => 2), VarName{sym, typeof(identity)} where sym[x, y], UnitRange{Int64}[1:1, 2:2], [1.0, 2.0], Any[identity, identity], Bool[0, 0], OrderedDict{Int64, Int64}()) +""" +function loosen_types!!( + vnv::VarNamedVector, ::Type{KNew}, ::Type{TransNew} +) where {KNew,TransNew} + K = eltype(vnv.varnames) + Trans = eltype(vnv.transforms) + if KNew <: K && TransNew <: Trans + return vnv + else + vn_type = promote_type(K, KNew) + transform_type = promote_type(Trans, TransNew) + return VarNamedVector( + OrderedDict{vn_type,Int}(vnv.varname_to_index), + Vector{vn_type}(vnv.varnames), + vnv.ranges, + vnv.vals, + Vector{transform_type}(vnv.transforms), + vnv.is_unconstrained, + vnv.num_inactive, + ) + end +end + +""" + tighten_types(vnv::VarNamedVector) + +Return a copy of `vnv` with the most concrete types possible. + +For instance, if `vnv` has element type `Real`, but all the values are actually `Float64`s, +then `tighten_types(vnv)` will have element type `Float64`. + +# See also +[`loosen_types!!`](@ref) +""" +function tighten_types(vnv::VarNamedVector) + return VarNamedVector( + OrderedDict(vnv.varname_to_index...), + [vnv.varnames...], + copy(vnv.ranges), + [vnv.vals...], + [vnv.transforms...], + copy(vnv.is_unconstrained), + copy(vnv.num_inactive), + ) +end + +function BangBang.push!!( + vnv::VarNamedVector, vn::VarName, val, transform=from_vec_transform(val) +) + vnv = loosen_types!!(vnv, typeof(vn), typeof(transform)) + push!(vnv, vn, val, transform) + return vnv +end + +# TODO(mhauru) The gidset and num_produce arguments are used by the old Gibbs sampler. +# Remove this method as soon as possible. +function BangBang.push!!(vnv::VarNamedVector, vn, val, dist, gidset, num_produce) + f = from_vec_transform(dist) + return push!!(vnv, vn, val, f) +end + +""" + shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) + +Shifts the elements of `x` starting from index `start` by `n` to the right. +""" +function shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) + x[(start + n):end] = x[start:(end - n)] + return x +end + +""" + shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) + +Shifts the ranges of variables in `vnv` starting from index `idx` by `n`. +""" +function shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) + for i in (idx + 1):length(vnv.ranges) + vnv.ranges[i] = vnv.ranges[i] .+ n + end + return nothing +end + +""" + update!(vnv::VarNamedVector, vn::VarName, val[, transform]) + +Either add a new entry or update existing entry for `vn` in `vnv` with the value `val`. + +If `vn` does not exist in `vnv`, this is equivalent to [`push!`](@ref). + +By default `transform` is the one that converts the value to a vector, which is how it is +stored in `vnv`. +""" +function update!(vnv::VarNamedVector, vn::VarName, val, transform=from_vec_transform(val)) + if !haskey(vnv, vn) + # Here we just add a new entry. + return push!(vnv, vn, val, transform) + end + + # Here we update an existing entry. + val_vec = tovec(val) + idx = getidx(vnv, vn) + # Extract the old range. + r_old = getrange(vnv, idx) + start_old, end_old = first(r_old), last(r_old) + n_old = length(r_old) + # Compute the new range. + n_new = length(val_vec) + start_new = start_old + end_new = start_old + n_new - 1 + r_new = start_new:end_new + + #= + Suppose we currently have the following: + + | x | x | o | o | o | y | y | y | <- Current entries + + where 'O' denotes an inactive entry, and we're going to + update the variable `x` to be of size `k` instead of 2. + + We then have a few different scenarios: + 1. `k > 5`: All inactive entries become active + need to shift `y` to the right. + E.g. if `k = 7`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | x | x | x | x | y | y | y | <- New entries + + 2. `k = 5`: All inactive entries become active. + Then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | x | x | y | y | y | <- New entries + + 3. `k < 5`: Some inactive entries become active, some remain inactive. + E.g. if `k = 3`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | x | o | o | y | y | y | <- New entries + + 4. `k = 2`: No inactive entries become active. + Then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | x | o | o | o | y | y | y | <- New entries + + 5. `k < 2`: More entries become inactive. + E.g. if `k = 1`, then + + | x | x | o | o | o | y | y | y | <- Current entries + | x | o | o | o | o | y | y | y | <- New entries + =# + + # Compute the allocated space for `vn`. + had_inactive = haskey(vnv.num_inactive, idx) + n_allocated = had_inactive ? n_old + vnv.num_inactive[idx] : n_old + + if n_new > n_allocated + # Then we need to grow the underlying vector. + n_extra = n_new - n_allocated + # Allocate. + resize!(vnv.vals, length(vnv.vals) + n_extra) + # Shift current values. + shift_right!(vnv.vals, end_old + 1, n_extra) + # No more inactive entries. + had_inactive && delete!(vnv.num_inactive, idx) + # Update the ranges for all variables after this one. + shift_subsequent_ranges_by!(vnv, idx, n_extra) + elseif n_new == n_allocated + # => No more inactive entries. + had_inactive && delete!(vnv.num_inactive, idx) + else + # `n_new < n_allocated` + # => Need to update the number of inactive entries. + vnv.num_inactive[idx] = n_allocated - n_new + end + + # Update the range for this variable. + vnv.ranges[idx] = r_new + # Update the value. + vnv.vals[r_new] = val_vec + # Update the transform. + vnv.transforms[idx] = transform + + # TODO: Should we maybe sweep over inactive ranges and re-contiguify + # if the total number of inactive elements is "large" in some sense? + + return nothing +end + +function update!!(vnv::VarNamedVector, vn::VarName, val, transform=from_vec_transform(val)) + vnv = loosen_types!!(vnv, typeof(vn), typeof(transform)) + update!(vnv, vn, val, transform) + return vnv +end + +# set!! is the function defined in utils.jl that tries to do fancy stuff with optics when +# setting the value of a generic container using a VarName. We can bypass all that because +# VarNamedVector handles VarNames natively. +set!!(vnv::VarNamedVector, vn::VarName, val) = update!!(vnv, vn, val) + +function setval!(vnv::VarNamedVector, val, vn::VarName) + return setindex_raw!(vnv, tovec(val), vn) +end + +function recontiguify_ranges!(ranges::AbstractVector{<:AbstractRange}) + offset = 0 + for i in 1:length(ranges) + r_old = ranges[i] + ranges[i] = (offset + 1):(offset + length(r_old)) + offset += length(r_old) + end + + return ranges +end + +""" + contiguify!(vnv::VarNamedVector) + +Re-contiguify the underlying vector and shrink if possible. + +# Examples + +```jldoctest varnamedvector-contiguify +julia> using DynamicPPL: VarNamedVector, @varname, contiguify!, update!, has_inactive + +julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [3.0]); + +julia> update!(vnv, @varname(x), [23.0, 24.0]); + +julia> has_inactive(vnv) +true + +julia> length(vnv.vals) +4 + +julia> contiguify!(vnv); + +julia> has_inactive(vnv) +false + +julia> length(vnv.vals) +3 + +julia> vnv[@varname(x)] # All the values are still there. +2-element Vector{Float64}: + 23.0 + 24.0 +``` +""" +function contiguify!(vnv::VarNamedVector) + # Extract the re-contiguified values. + # NOTE: We need to do this before we update the ranges. + old_vals = copy(vnv.vals) + old_ranges = copy(vnv.ranges) + # And then we re-contiguify the ranges. + recontiguify_ranges!(vnv.ranges) + # Clear the inactive ranges. + empty!(vnv.num_inactive) + # Now we update the values. + for (old_range, new_range) in zip(old_ranges, vnv.ranges) + vnv.vals[new_range] = old_vals[old_range] + end + # And (potentially) shrink the underlying vector. + resize!(vnv.vals, vnv.ranges[end][end]) + # The rest should be left as is. + return vnv +end + +""" + group_by_symbol(vnv::VarNamedVector) + +Return a dictionary mapping symbols to `VarNamedVector`s with varnames containing that +symbol. + +# Examples + +```jldoctest varnamedvector-group-by-symbol +julia> using DynamicPPL: VarNamedVector, @varname, group_by_symbol + +julia> vnv = VarNamedVector(@varname(x) => [1.0], @varname(y) => [2.0], @varname(x[1]) => [3.0]); + +julia> d = group_by_symbol(vnv); + +julia> collect(keys(d)) +[Symbol("x"), Symbol("y")] + +julia> d[@varname(x)] == VarNamedVector(@varname(x) => [1.0], @varname(x[1]) => [3.0]) +true + +julia> d[@varname(y)] == VarNamedVector(@varname(y) => [2.0]) +true +""" +function group_by_symbol(vnv::VarNamedVector) + symbols = unique(map(getsym, vnv.varnames)) + nt_vals = map(s -> tighten_types(subset(vnv, [VarName(s)])), symbols) + return OrderedDict(zip(symbols, nt_vals)) +end + +""" + shift_index_left!(vnv::VarNamedVector, idx::Int) + +Shift the index `idx` to the left by one and update the relevant fields. + +This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a +helper function for [`shift_subsequent_indices_left!`](@ref). + +!!! warning + This does not check if index we're shifting to is already occupied. +""" +function shift_index_left!(vnv::VarNamedVector, idx::Int) + # Shift the index in the lookup table. + vn = vnv.varnames[idx] + vnv.varname_to_index[vn] = idx - 1 + # Shift the index in the inactive ranges. + if haskey(vnv.num_inactive, idx) + # Done in increasing order => don't need to worry about + # potentially shifting the same index twice. + vnv.num_inactive[idx - 1] = pop!(vnv.num_inactive, idx) + end +end + +""" + shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int) + +Shift the indices for all variables after `idx` to the left by one and update the relevant + fields. + +This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a +helper function for [`delete!`](@ref). +""" +function shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int) + # Shift the indices for all variables after `idx`. + for idx_to_shift in (idx + 1):length(vnv.varnames) + shift_index_left!(vnv, idx_to_shift) + end +end + +function Base.delete!(vnv::VarNamedVector, vn::VarName) + # Error if we don't have the variable. + !haskey(vnv, vn) && throw(ArgumentError("variable name $vn does not exist")) + + # Get the index of the variable. + idx = getidx(vnv, vn) + + # Delete the values. + r_start = first(getrange(vnv, idx)) + n_allocated = num_allocated(vnv, idx) + # NOTE: `deleteat!` also results in a `resize!` so we don't need to do that. + deleteat!(vnv.vals, r_start:(r_start + n_allocated - 1)) + + # Delete `vn` from the lookup table. + delete!(vnv.varname_to_index, vn) + + # Delete any inactive ranges corresponding to `vn`. + haskey(vnv.num_inactive, idx) && delete!(vnv.num_inactive, idx) + + # Re-adjust the indices for varnames occuring after `vn` so + # that they point to the correct indices after the deletions below. + shift_subsequent_indices_left!(vnv, idx) + + # Re-adjust the ranges for varnames occuring after `vn`. + shift_subsequent_ranges_by!(vnv, idx, -n_allocated) + + # Delete references from vector fields, thus shifting the indices of + # varnames occuring after `vn` by one to the left, as we adjusted for above. + deleteat!(vnv.varnames, idx) + deleteat!(vnv.ranges, idx) + deleteat!(vnv.transforms, idx) + + return vnv +end + +""" + values_as(vnv::VarNamedVector[, T]) + +Return the values/realizations in `vnv` as type `T`, if implemented. + +If no type `T` is provided, return values as stored in `vnv`. + +# Examples + +```jldoctest +julia> using DynamicPPL: VarNamedVector + +julia> vnv = VarNamedVector(@varname(x) => 1, @varname(y) => [2.0]); + +julia> values_as(vnv) == [1.0, 2.0] +true + +julia> values_as(vnv, Vector{Float32}) == Vector{Float32}([1.0, 2.0]) +true + +julia> values_as(vnv, OrderedDict) == OrderedDict(@varname(x) => 1.0, @varname(y) => [2.0]) +true + +julia> values_as(vnv, NamedTuple) == (x = 1.0, y = [2.0]) +true +``` +""" +values_as(vnv::VarNamedVector) = values_as(vnv, Vector) +values_as(vnv::VarNamedVector, ::Type{Vector}) = vnv[:] +function values_as(vnv::VarNamedVector, ::Type{Vector{T}}) where {T} + return convert(Vector{T}, values_as(vnv, Vector)) +end +function values_as(vnv::VarNamedVector, ::Type{NamedTuple}) + return NamedTuple(zip(map(Symbol, keys(vnv)), values(vnv))) +end +function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict} + return ConstructionBase.constructorof(D)(pairs(vnv)) +end + +# TODO(mhauru) This is tricky to implement in the general case, and the below implementation +# only covers some simple cases. It's probably sufficient in most situations though. +function hasvalue(vnv::VarNamedVector, vn::VarName) + haskey(vnv, vn) && return true + any(subsumes(vn, k) for k in keys(vnv)) && return true + # Handle the easy case where the right symbol isn't even present. + !any(k -> getsym(k) == getsym(vn), keys(vnv)) && return false + + optic = getoptic(vn) + if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic + # If vn is of the form @varname(somesymbol[someindex]), we check whether we store + # @varname(somesymbol) and can index into it with someindex. If we rather have a + # composed optic with the last part being an index lens, we do a similar check but + # stripping out the last index lens part. If these pass, the answer is definitely + # "yes". If not, we still don't know for sure. + # TODO(mhauru) What about casese where vnv stores both @varname(x) and + # @varname(x[1]) or @varname(x.a)? Those should probably be banned, but currently + # aren't. + head, tail = if optic isa Accessors.ComposedOptic + decomp_optic = Accessors.decompose(optic) + first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) + else + optic, identity + end + parent_varname = VarName{getsym(vn)}(tail) + if haskey(vnv, parent_varname) + valvec = getindex(vnv, parent_varname) + return canview(head, valvec) + end + end + throw(ErrorException("hasvalue has not been fully implemented for this VarName: $(vn)")) +end + +# TODO(mhauru) Like hasvalue, this is only partially implemented. +function getvalue(vnv::VarNamedVector, vn::VarName) + !hasvalue(vnv, vn) && throw(KeyError(vn)) + haskey(vnv, vn) && getindex(vnv, vn) + + subsumed_keys = filter(k -> subsumes(vn, k), keys(vnv)) + if length(subsumed_keys) > 0 + # TODO(mhauru) What happens if getindex returns e.g. matrices, and we vcat them? + return mapreduce(k -> getindex(vnv, k), vcat, subsumed_keys) + end + + optic = getoptic(vn) + # See hasvalue for some comments on the logic of this if block. + if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic + head, tail = if optic isa Accessors.ComposedOptic + decomp_optic = Accessors.decompose(optic) + first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) + else + optic, identity + end + parent_varname = VarName{getsym(vn)}(tail) + valvec = getindex(vnv, parent_varname) + return head(valvec) + end + throw(ErrorException("getvalue has not been fully implemented for this VarName: $(vn)")) +end + +Base.get(vnv::VarNamedVector, vn::VarName) = getvalue(vnv, vn) diff --git a/src/varnamevector.jl b/src/varnamevector.jl deleted file mode 100644 index ed919a20c..000000000 --- a/src/varnamevector.jl +++ /dev/null @@ -1,756 +0,0 @@ -""" - VarNameVector - -A container that works like a `Vector` and an `OrderedDict` but is neither. - -# Fields -$(FIELDS) -""" -struct VarNameVector{ - K<:VarName,V,TVN<:AbstractVector{K},TVal<:AbstractVector{V},TTrans<:AbstractVector,MData -} - "mapping from the `VarName` to its integer index in `varnames`, `ranges` and `dists`" - varname_to_index::OrderedDict{K,Int} - - "vector of identifiers for the random variables, where `varnames[varname_to_index[vn]] == vn`" - varnames::TVN # AbstractVector{<:VarName} - - "vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has a single index or a set of contiguous indices in `vals`" - ranges::Vector{UnitRange{Int}} - - "vector of values of all variables; the value(s) of `vn` is/are `vals[ranges[varname_to_index[vn]]]`" - vals::TVal # AbstractVector{<:Real} - - "vector of transformations whose inverse takes us back to the original space" - transforms::TTrans - - "specifies whether a variable is transformed or not " - is_transformed::BitVector - - "additional entries which are considered inactive" - num_inactive::OrderedDict{Int,Int} - - "metadata associated with the varnames" - metadata::MData -end - -function ==(vnv_left::VarNameVector, vnv_right::VarNameVector) - return vnv_left.varname_to_index == vnv_right.varname_to_index && - vnv_left.varnames == vnv_right.varnames && - vnv_left.ranges == vnv_right.ranges && - vnv_left.vals == vnv_right.vals && - vnv_left.transforms == vnv_right.transforms && - vnv_left.is_transformed == vnv_right.is_transformed && - vnv_left.num_inactive == vnv_right.num_inactive && - vnv_left.metadata == vnv_right.metadata -end - -function VarNameVector( - varname_to_index, - varnames, - ranges, - vals, - transforms, - is_transformed=fill!(BitVector(undef, length(varnames)), 0), -) - return VarNameVector( - varname_to_index, - varnames, - ranges, - vals, - transforms, - is_transformed, - OrderedDict{Int,Int}(), - nothing, - ) -end -# TODO: Do we need this? -function VarNameVector{K,V}() where {K,V} - return VarNameVector(OrderedDict{K,Int}(), K[], UnitRange{Int}[], V[], Any[]) -end - -istrans(vnv::VarNameVector, vn::VarName) = vnv.is_transformed[vnv.varname_to_index[vn]] -function settrans!(vnv::VarNameVector, val::Bool, vn::VarName) - return vnv.is_transformed[vnv.varname_to_index[vn]] = val -end - -VarNameVector() = VarNameVector{VarName,Real}() -VarNameVector(xs::Pair...) = VarNameVector(OrderedDict(xs...)) -VarNameVector(x::AbstractDict) = VarNameVector(keys(x), values(x)) -VarNameVector(varnames, vals) = VarNameVector(collect_maybe(varnames), collect_maybe(vals)) -function VarNameVector( - varnames::AbstractVector, vals::AbstractVector, transforms=map(from_vec_transform, vals) -) - # TODO: Check uniqueness of `varnames`? - - # Convert `vals` into a vector of vectors. - vals_vecs = map(tovec, vals) - - # TODO: Is this really the way to do this? - if !(eltype(varnames) <: VarName) - varnames = convert(Vector{VarName}, varnames) - end - varname_to_index = OrderedDict{eltype(varnames),Int}() - ranges = Vector{UnitRange{Int}}() - offset = 0 - for (i, (vn, x)) in enumerate(zip(varnames, vals_vecs)) - # Add the varname index. - push!(varname_to_index, vn => length(varname_to_index) + 1) - # Add the range. - r = (offset + 1):(offset + length(x)) - push!(ranges, r) - # Update the offset. - offset = r[end] - end - - return VarNameVector( - varname_to_index, varnames, ranges, reduce(vcat, vals_vecs), transforms - ) -end - -""" - replace_values(vnv::VarNameVector, vals::AbstractVector) - -Replace the values in `vnv` with `vals`. - -This is useful when we want to update the entire underlying vector of values -in one go or if we want to change the how the values are stored, e.g. alter the `eltype`. - -!!! warning - This replaces the raw underlying values, and so care should be taken when using this - function. For example, if `vnv` has any inactive entries, then the provided `vals` - should also contain the inactive entries to avoid unexpected behavior. - -# Example - -```jldoctest varnamevector-replace-values -julia> using DynamicPPL: VarNameVector, replace_values - -julia> vnv = VarNameVector(@varname(x) => [1.0]); - -julia> replace_values(vnv, [2.0])[@varname(x)] == [2.0] -true -``` - -This is also useful when we want to differentiate wrt. the values -using automatic differentiation, e.g. ForwardDiff.jl. - -```jldoctest varnamevector-replace-values -julia> using ForwardDiff: ForwardDiff - -julia> f(x) = sum(abs2, replace_values(vnv, x)[@varname(x)]) -f (generic function with 1 method) - -julia> ForwardDiff.gradient(f, [1.0]) -1-element Vector{Float64}: - 2.0 -``` -""" -replace_values(vnv::VarNameVector, vals) = Setfield.@set vnv.vals = vals - -# Some `VarNameVector` specific functions. -getidx(vnv::VarNameVector, vn::VarName) = vnv.varname_to_index[vn] - -getrange(vnv::VarNameVector, idx::Int) = vnv.ranges[idx] -getrange(vnv::VarNameVector, vn::VarName) = getrange(vnv, getidx(vnv, vn)) - -gettransform(vnv::VarNameVector, vn::VarName) = vnv.transforms[getidx(vnv, vn)] - -""" - has_inactive(vnv::VarNameVector) - -Returns `true` if `vnv` has inactive ranges. -""" -has_inactive(vnv::VarNameVector) = !isempty(vnv.num_inactive) - -""" - num_inactive(vnv::VarNameVector) - -Return the number of inactive entries in `vnv`. -""" -num_inactive(vnv::VarNameVector) = sum(values(vnv.num_inactive)) - -""" - num_inactive(vnv::VarNameVector, vn::VarName) - -Returns the number of inactive entries for `vn` in `vnv`. -""" -num_inactive(vnv::VarNameVector, vn::VarName) = num_inactive(vnv, getidx(vnv, vn)) -num_inactive(vnv::VarNameVector, idx::Int) = get(vnv.num_inactive, idx, 0) - -""" - num_allocated(vnv::VarNameVector) - -Returns the number of allocated entries in `vnv`. -""" -num_allocated(vnv::VarNameVector) = length(vnv.vals) - -""" - num_allocated(vnv::VarNameVector, vn::VarName) - -Returns the number of allocated entries for `vn` in `vnv`. -""" -num_allocated(vnv::VarNameVector, vn::VarName) = num_allocated(vnv, getidx(vnv, vn)) -function num_allocated(vnv::VarNameVector, idx::Int) - return length(getrange(vnv, idx)) + num_inactive(vnv, idx) -end - -# Basic array interface. -Base.eltype(vnv::VarNameVector) = eltype(vnv.vals) -Base.length(vnv::VarNameVector) = - if isempty(vnv.num_inactive) - length(vnv.vals) - else - sum(length, vnv.ranges) - end -Base.size(vnv::VarNameVector) = (length(vnv),) -Base.isempty(vnv::VarNameVector) = isempty(vnv.varnames) - -# TODO: We should probably remove this -Base.IndexStyle(::Type{<:VarNameVector}) = IndexLinear() - -# Dictionary interface. -Base.keys(vnv::VarNameVector) = vnv.varnames -Base.values(vnv::VarNameVector) = Iterators.map(Base.Fix1(getindex, vnv), vnv.varnames) -Base.pairs(vnv::VarNameVector) = (vn => vnv[vn] for vn in keys(vnv)) - -Base.haskey(vnv::VarNameVector, vn::VarName) = haskey(vnv.varname_to_index, vn) - -# `getindex` & `setindex!` -Base.getindex(vnv::VarNameVector, i::Int) = getindex_raw(vnv, i) -function Base.getindex(vnv::VarNameVector, vn::VarName) - x = getindex_raw(vnv, vn) - f = gettransform(vnv, vn) - return f(x) -end - -function find_range_from_sorted(ranges::AbstractVector{<:AbstractRange}, x) - # TODO: Assume `ranges` to be sorted and contiguous, and use `searchsortedfirst` - # for a more efficient approach. - range_idx = findfirst(Base.Fix1(∈, x), ranges) - - # If we're out of bounds, we raise an error. - if range_idx === nothing - throw(ArgumentError("Value $x is not in any of the ranges.")) - end - - return range_idx -end - -function adjusted_ranges(vnv::VarNameVector) - # Every range following inactive entries needs to be shifted. - offset = 0 - ranges_adj = similar(vnv.ranges) - for (idx, r) in enumerate(vnv.ranges) - # Remove the `offset` in `r` due to inactive entries. - ranges_adj[idx] = r .- offset - # Update `offset`. - offset += get(vnv.num_inactive, idx, 0) - end - - return ranges_adj -end - -function index_to_raw_index(vnv::VarNameVector, i::Int) - # If we don't have any inactive entries, there's nothing to do. - has_inactive(vnv) || return i - - # Get the adjusted ranges. - ranges_adj = adjusted_ranges(vnv) - # Determine the adjusted range that the index corresponds to. - r_idx = find_range_from_sorted(ranges_adj, i) - r = vnv.ranges[r_idx] - # Determine how much of the index `i` is used to get to this range. - i_used = r_idx == 1 ? 0 : sum(length, ranges_adj[1:(r_idx - 1)]) - # Use remainder to index into `r`. - i_remainder = i - i_used - return r[i_remainder] -end - -getindex_raw(vnv::VarNameVector, i::Int) = vnv.vals[index_to_raw_index(vnv, i)] -getindex_raw(vnv::VarNameVector, vn::VarName) = vnv.vals[getrange(vnv, vn)] - -# `getindex` for `Colon` -function Base.getindex(vnv::VarNameVector, ::Colon) - return if has_inactive(vnv) - mapreduce(Base.Fix1(getindex, vnv.vals), vcat, vnv.ranges) - else - vnv.vals - end -end - -function getindex_raw(vnv::VarNameVector, ::Colon) - return if has_inactive(vnv) - mapreduce(Base.Fix1(getindex_raw, vnv.vals), vcat, vnv.ranges) - else - vnv.vals - end -end - -# HACK: remove this as soon as possible. -Base.getindex(vnv::VarNameVector, spl::AbstractSampler) = vnv[:] - -Base.setindex!(vnv::VarNameVector, val, i::Int) = setindex_raw!(vnv, val, i) -function Base.setindex!(vnv::VarNameVector, val, vn::VarName) - f = inverse(gettransform(vnv, vn)) - return setindex_raw!(vnv, f(val), vn) -end - -setindex_raw!(vnv::VarNameVector, val, i::Int) = vnv.vals[index_to_raw_index(vnv, i)] = val -function setindex_raw!(vnv::VarNameVector, val::AbstractVector, vn::VarName) - return vnv.vals[getrange(vnv, vn)] = val -end - -# `empty!(!)` -function Base.empty!(vnv::VarNameVector) - # TODO: Or should the semantics be different, e.g. keeping `varnames`? - empty!(vnv.varname_to_index) - empty!(vnv.varnames) - empty!(vnv.ranges) - empty!(vnv.vals) - empty!(vnv.transforms) - empty!(vnv.num_inactive) - return nothing -end -BangBang.empty!!(vnv::VarNameVector) = (empty!(vnv); return vnv) - -function Base.merge(left_vnv::VarNameVector, right_vnv::VarNameVector) - # Return early if possible. - isempty(left_vnv) && return deepcopy(right_vnv) - isempty(right_vnv) && return deepcopy(left_vnv) - - # Determine varnames. - vns_left = left_vnv.varnames - vns_right = right_vnv.varnames - vns_both = union(vns_left, vns_right) - - # Determine `eltype` of `vals`. - T_left = eltype(left_vnv.vals) - T_right = eltype(right_vnv.vals) - T = promote_type(T_left, T_right) - # TODO: Is this necessary? - if !(T <: Real) - T = Real - end - - # Determine `eltype` of `varnames`. - V_left = eltype(left_vnv.varnames) - V_right = eltype(right_vnv.varnames) - V = promote_type(V_left, V_right) - if !(V <: VarName) - V = VarName - end - - # Determine `eltype` of `transforms`. - F_left = eltype(left_vnv.transforms) - F_right = eltype(right_vnv.transforms) - F = promote_type(F_left, F_right) - - # Allocate. - varnames_to_index = OrderedDict{V,Int}() - ranges = UnitRange{Int}[] - vals = T[] - transforms = F[] - is_transformed = BitVector(undef, length(vns_both)) - - # Range offset. - offset = 0 - - for (idx, vn) in enumerate(vns_both) - # Extract the necessary information from `left` or `right`. - if vn in vns_left && !(vn in vns_right) - # `vn` is only in `left`. - varnames_to_index[vn] = idx - val = getindex_raw(left_vnv, vn) - n = length(val) - r = (offset + 1):(offset + n) - f = gettransform(left_vnv, vn) - is_transformed[idx] = istrans(left_vnv, vn) - else - # `vn` is either in both or just `right`. - varnames_to_index[vn] = idx - val = getindex_raw(right_vnv, vn) - n = length(val) - r = (offset + 1):(offset + n) - f = gettransform(right_vnv, vn) - is_transformed[idx] = istrans(right_vnv, vn) - end - # Update. - append!(vals, val) - push!(ranges, r) - push!(transforms, f) - # Increment `offset`. - offset += n - end - - return VarNameVector(varnames_to_index, vns_both, ranges, vals, transforms) -end - -function subset(vnv::VarNameVector, vns::AbstractVector{<:VarName}) - # NOTE: This does not specialize types when possible. - vnv_new = similar(vnv) - # Return early if possible. - isempty(vnv) && return vnv_new - - for vn in vns - push!(vnv_new, vn, getindex_internal(vnv, vn), gettransform(vnv, vn)) - end - - return vnv_new -end - -# `similar` -similar_metadata(::Nothing) = nothing -similar_metadata(x::Union{AbstractArray,AbstractDict}) = similar(x) -function Base.similar(vnv::VarNameVector) - # NOTE: Whether or not we should empty the underlying containers or note - # is somewhat ambiguous. For example, `similar(vnv.varname_to_index)` will - # result in an empty `AbstractDict`, while the vectors, e.g. `vnv.ranges`, - # will result in non-empty vectors but with entries as `undef`. But it's - # much easier to write the rest of the code assuming that `undef` is not - # present, and so for now we empty the underlying containers, thus differing - # from the behavior of `similar` for `AbstractArray`s. - return VarNameVector( - similar(vnv.varname_to_index), - similar(vnv.varnames, 0), - similar(vnv.ranges, 0), - similar(vnv.vals, 0), - similar(vnv.transforms, 0), - BitVector(), - similar(vnv.num_inactive), - similar_metadata(vnv.metadata), - ) -end - -""" - is_contiguous(vnv::VarNameVector) - -Returns `true` if the underlying data of `vnv` is stored in a contiguous array. - -This is equivalent to negating [`has_inactive(vnv)`](@ref). -""" -is_contiguous(vnv::VarNameVector) = !has_inactive(vnv) - -function nextrange(vnv::VarNameVector, x) - # If `vnv` is empty, return immediately. - isempty(vnv) && return 1:length(x) - - # The offset will be the last range's end + its number of inactive entries. - vn_last = vnv.varnames[end] - idx = getidx(vnv, vn_last) - offset = last(getrange(vnv, idx)) + num_inactive(vnv, idx) - - return (offset + 1):(offset + length(x)) -end - -# `push!` and `push!!`: add a variable to the varname vector. -function Base.push!(vnv::VarNameVector, vn::VarName, val, transform=from_vec_transform(val)) - # Error if we already have the variable. - haskey(vnv, vn) && throw(ArgumentError("variable name $vn already exists")) - # NOTE: We need to compute the `nextrange` BEFORE we start mutating - # the underlying; otherwise we might get some strange behaviors. - val_vec = tovec(val) - r_new = nextrange(vnv, val_vec) - vnv.varname_to_index[vn] = length(vnv.varname_to_index) + 1 - push!(vnv.varnames, vn) - push!(vnv.ranges, r_new) - append!(vnv.vals, val_vec) - push!(vnv.transforms, transform) - push!(vnv.is_transformed, false) - return nothing -end - -""" - shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) - -Shifts the elements of `x` starting from index `start` by `n` to the right. -""" -function shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) - x[(start + n):end] = x[start:(end - n)] - return x -end - -""" - shift_subsequent_ranges_by!(vnv::VarNameVector, idx::Int, n) - -Shifts the ranges of variables in `vnv` starting from index `idx` by `n`. -""" -function shift_subsequent_ranges_by!(vnv::VarNameVector, idx::Int, n) - for i in (idx + 1):length(vnv.ranges) - vnv.ranges[i] = vnv.ranges[i] .+ n - end - return nothing -end - -# `update!` and `update!!`: update a variable in the varname vector. -""" - update!(vnv::VarNameVector, vn::VarName, val[, transform]) - -Either add a new entry or update existing entry for `vn` in `vnv` with the value `val`. - -If `vn` does not exist in `vnv`, this is equivalent to [`push!`](@ref). -""" -function update!(vnv::VarNameVector, vn::VarName, val, transform=from_vec_transform(val)) - if !haskey(vnv, vn) - # Here we just add a new entry. - return push!(vnv, vn, val, transform) - end - - # Here we update an existing entry. - val_vec = tovec(val) - idx = getidx(vnv, vn) - # Extract the old range. - r_old = getrange(vnv, idx) - start_old, end_old = first(r_old), last(r_old) - n_old = length(r_old) - # Compute the new range. - n_new = length(val_vec) - start_new = start_old - end_new = start_old + n_new - 1 - r_new = start_new:end_new - - #= - Suppose we currently have the following: - - | x | x | o | o | o | y | y | y | <- Current entries - - where 'O' denotes an inactive entry, and we're going to - update the variable `x` to be of size `k` instead of 2. - - We then have a few different scenarios: - 1. `k > 5`: All inactive entries become active + need to shift `y` to the right. - E.g. if `k = 7`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | x | x | x | x | y | y | y | <- New entries - - 2. `k = 5`: All inactive entries become active. - Then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | x | x | y | y | y | <- New entries - - 3. `k < 5`: Some inactive entries become active, some remain inactive. - E.g. if `k = 3`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | o | o | y | y | y | <- New entries - - 4. `k = 2`: No inactive entries become active. - Then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | o | o | o | y | y | y | <- New entries - - 5. `k < 2`: More entries become inactive. - E.g. if `k = 1`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | o | o | o | o | y | y | y | <- New entries - =# - - # Compute the allocated space for `vn`. - had_inactive = haskey(vnv.num_inactive, idx) - n_allocated = had_inactive ? n_old + vnv.num_inactive[idx] : n_old - - if n_new > n_allocated - # Then we need to grow the underlying vector. - n_extra = n_new - n_allocated - # Allocate. - resize!(vnv.vals, length(vnv.vals) + n_extra) - # Shift current values. - shift_right!(vnv.vals, end_old + 1, n_extra) - # No more inactive entries. - had_inactive && delete!(vnv.num_inactive, idx) - # Update the ranges for all variables after this one. - shift_subsequent_ranges_by!(vnv, idx, n_extra) - elseif n_new == n_allocated - # => No more inactive entries. - had_inactive && delete!(vnv.num_inactive, idx) - else - # `n_new < n_allocated` - # => Need to update the number of inactive entries. - vnv.num_inactive[idx] = n_allocated - n_new - end - - # Update the range for this variable. - vnv.ranges[idx] = r_new - # Update the value. - vnv.vals[r_new] = val_vec - # Update the transform. - vnv.transforms[idx] = transform - - # TODO: Should we maybe sweep over inactive ranges and re-contiguify - # if we the total number of inactive elements is "large" in some sense? - - return nothing -end - -function recontiguify_ranges!(ranges::AbstractVector{<:AbstractRange}) - offset = 0 - for i in 1:length(ranges) - r_old = ranges[i] - ranges[i] = (offset + 1):(offset + length(r_old)) - offset += length(r_old) - end - - return ranges -end - -""" - contiguify!(vnv::VarNameVector) - -Re-contiguify the underlying vector and shrink if possible. -""" -function contiguify!(vnv::VarNameVector) - # Extract the re-contiguified values. - # NOTE: We need to do this before we update the ranges. - vals = vnv[:] - # And then we re-contiguify the ranges. - recontiguify_ranges!(vnv.ranges) - # Clear the inactive ranges. - empty!(vnv.num_inactive) - # Now we update the values. - for (i, r) in enumerate(vnv.ranges) - vnv.vals[r] = vals[r] - end - # And (potentially) shrink the underlying vector. - resize!(vnv.vals, vnv.ranges[end][end]) - # The rest should be left as is. - return vnv -end - -""" - group_by_symbol(vnv::VarNameVector) - -Return a dictionary mapping symbols to `VarNameVector`s with -varnames containing that symbol. -""" -function group_by_symbol(vnv::VarNameVector) - # Group varnames in `vnv` by the symbol. - d = OrderedDict{Symbol,Vector{VarName}}() - for vn in vnv.varnames - push!(get!(d, getsym(vn), Vector{VarName}()), vn) - end - - # Create a `NamedTuple` from the grouped varnames. - nt_vals = map(values(d)) do varnames - # TODO: Do we need to specialize the inputs here? - VarNameVector( - map(identity, varnames), - map(Base.Fix1(getindex, vnv), varnames), - map(Base.Fix1(gettransform, vnv), varnames), - ) - end - - return OrderedDict(zip(keys(d), nt_vals)) -end - -""" - shift_index_left!(vnv::VarNameVector, idx::Int) - -Shift the index `idx` to the left by one and update the relevant fields. - -!!! warning - This does not check if index we're shifting to is already occupied. -""" -function shift_index_left!(vnv::VarNameVector, idx::Int) - # Shift the index in the lookup table. - vn = vnv.varnames[idx] - vnv.varname_to_index[vn] = idx - 1 - # Shift the index in the inactive ranges. - if haskey(vnv.num_inactive, idx) - # Done in increasing order => don't need to worry about - # potentially shifting the same index twice. - vnv.num_inactive[idx - 1] = pop!(vnv.num_inactive, idx) - end -end - -""" - shift_subsequent_indices_left!(vnv::VarNameVector, idx::Int) - -Shift the indices for all variables after `idx` to the left by one and update -the relevant fields. - -This just -""" -function shift_subsequent_indices_left!(vnv::VarNameVector, idx::Int) - # Shift the indices for all variables after `idx`. - for idx_to_shift in (idx + 1):length(vnv.varnames) - shift_index_left!(vnv, idx_to_shift) - end -end - -function Base.delete!(vnv::VarNameVector, vn::VarName) - # Error if we don't have the variable. - !haskey(vnv, vn) && throw(ArgumentError("variable name $vn does not exist")) - - # Get the index of the variable. - idx = getidx(vnv, vn) - - # Delete the values. - r_start = first(getrange(vnv, idx)) - n_allocated = num_allocated(vnv, idx) - # NOTE: `deleteat!` also results in a `resize!` so we don't need to do that. - deleteat!(vnv.vals, r_start:(r_start + n_allocated - 1)) - - # Delete `vn` from the lookup table. - delete!(vnv.varname_to_index, vn) - - # Delete any inactive ranges corresponding to `vn`. - haskey(vnv.num_inactive, idx) && delete!(vnv.num_inactive, idx) - - # Re-adjust the indices for varnames occuring after `vn` so - # that they point to the correct indices after the deletions below. - shift_subsequent_indices_left!(vnv, idx) - - # Re-adjust the ranges for varnames occuring after `vn`. - shift_subsequent_ranges_by!(vnv, idx, -n_allocated) - - # Delete references from vector fields, thus shifting the indices of - # varnames occuring after `vn` by one to the left, as we adjusted for above. - deleteat!(vnv.varnames, idx) - deleteat!(vnv.ranges, idx) - deleteat!(vnv.transforms, idx) - - return vnv -end - -""" - values_as(vnv::VarNameVector[, T]) - -Return the values/realizations in `vnv` as type `T`, if implemented. - -If no type `T` is provided, return values as stored in `vnv`. - -# Examples - -```jldoctest -julia> using DynamicPPL: VarNameVector - -julia> vnv = VarNameVector(@varname(x) => 1, @varname(y) => [2.0]); - -julia> values_as(vnv) == [1.0, 2.0] -true - -julia> values_as(vnv, Vector{Float32}) == Vector{Float32}([1.0, 2.0]) -true - -julia> values_as(vnv, OrderedDict) == OrderedDict(@varname(x) => 1.0, @varname(y) => [2.0]) -true - -julia> values_as(vnv, NamedTuple) == (x = 1.0, y = [2.0]) -true -``` -""" -values_as(vnv::VarNameVector) = values_as(vnv, Vector) -values_as(vnv::VarNameVector, ::Type{Vector}) = vnv[:] -function values_as(vnv::VarNameVector, ::Type{Vector{T}}) where {T} - return convert(Vector{T}, values_as(vnv, Vector)) -end -function values_as(vnv::VarNameVector, ::Type{NamedTuple}) - return NamedTuple(zip(map(Symbol, keys(vnv)), values(vnv))) -end -function values_as(vnv::VarNameVector, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(pairs(vnv)) -end diff --git a/test/compiler.jl b/test/compiler.jl index f1f06eabe..f2d7e5852 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -309,11 +309,11 @@ module Issue537 end vi2 = VarInfo(f2()) vi3 = VarInfo(f3()) @test haskey(vi1.metadata, :y) - @test vi1.metadata.y.vns[1] == @varname(y) + @test first(Base.keys(vi1.metadata.y)) == @varname(y) @test haskey(vi2.metadata, :y) - @test vi2.metadata.y.vns[1] == @varname(y[2][:, 1]) + @test first(Base.keys(vi2.metadata.y)) == @varname(y[2][:, 1]) @test haskey(vi3.metadata, :y) - @test vi3.metadata.y.vns[1] == @varname(y[1]) + @test first(Base.keys(vi3.metadata.y)) == @varname(y[1]) # Conditioning f1_c = f1() | (y=1,) diff --git a/test/model.jl b/test/model.jl index 60a8d2461..eaad848f1 100644 --- a/test/model.jl +++ b/test/model.jl @@ -122,7 +122,6 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test logjoints[i] ≈ DynamicPPL.TestUtils.logjoint_true(model, samples[:s], samples[:m]) end - println("\n model $(model) passed !!! \n") end end @@ -200,10 +199,10 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end - @testset "Dynamic constraints" begin + @testset "Dynamic constraints, Metadata" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() - vi = VarInfo(model) spl = SampleFromPrior() + vi = VarInfo(model, spl, DefaultContext(), DynamicPPL.Metadata) link!!(vi, spl, model) for i in 1:10 @@ -216,6 +215,18 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end + @testset "Dynamic constraints, VectorVarInfo" begin + model = DynamicPPL.TestUtils.demo_dynamic_constraint() + vi = VarInfo(model) + vi = link!!(vi, model) + + for i in 1:10 + # Sample with large variations. + vi[@varname(m)] = randn() * 10 + model(vi) + end + end + @testset "rand" begin model = gdemo_default @@ -324,7 +335,6 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true chain = MCMCChains.Chains( permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,) ) - display(chain) # Test! results = generated_quantities(model, chain) @@ -345,7 +355,6 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true vcat(syms, [:y]); info=(varname_to_symbol=vns_to_syms_with_extra,), ) - display(chain_with_extra) # Test! results = generated_quantities(model, chain_with_extra) for (x_true, result) in zip(xs, results) @@ -358,6 +367,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true models_to_test = [ DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] + context = DefaultContext() @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -366,18 +376,16 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - @test ( - @inferred(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())); + @test begin + @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) true - ) + end varinfo_linked = DynamicPPL.link(varinfo, model) - @test ( - @inferred( - DynamicPPL.evaluate!!(model, varinfo_linked, DefaultContext()) - ); + @test begin + @inferred(DynamicPPL.evaluate!!(model, varinfo_linked, context)) true - ) + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index 806ab8223..9596067eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,7 +40,7 @@ include("test_util.jl") @testset "interface" begin include("utils.jl") include("compiler.jl") - include("varnamevector.jl") + include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 74dcdd842..2684774bf 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -56,13 +56,41 @@ @test !haskey(svi, @varname(m.a[2])) @test !haskey(svi, @varname(m.a.b)) end + + @testset "VarNamedVector" begin + svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m), 1.0)) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test !haskey(svi, @varname(m[1])) + + svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m), [1.0])) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m[1])) + @test !haskey(svi, @varname(m[2])) + @test svi[@varname(m)][1] == svi[@varname(m[1])] + + svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m.a), [1.0])) + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) + # The implementation of haskey and getvalue fo VarNamedVector is incomplete, the + # next test is here to remind of us that. + svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m.a.b), [1.0])) + @test_broken (svi[@varname(m.a.b.c.d)]; true) + end end @testset "link!! & invlink!! on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model) + SimpleVarInfo(Dict()), + SimpleVarInfo(values_constrained), + SimpleVarInfo(VarNamedVector()), + VarInfo(model), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) @@ -115,12 +143,19 @@ # to see whether this is the case. svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) + vnv = VarNamedVector() + for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) + vnv = push!!(vnv, VarName{k}(), v) + end + svi_vnv = SimpleVarInfo(vnv) @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( svi_nt, svi_dict, + svi_vnv, DynamicPPL.settrans!!(svi_nt, true), DynamicPPL.settrans!!(svi_dict, true), + DynamicPPL.settrans!!(svi_vnv, true), ) # RandOM seed is set in each `@testset`, so we need to sample # a new realization for `m` here. @@ -195,36 +230,39 @@ model = DynamicPPL.TestUtils.demo_dynamic_constraint() # Initialize. - svi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi = last(DynamicPPL.evaluate!!(model, svi, SamplingContext())) - - # Sample with large variations in unconstrained space. - for i in 1:10 - for vn in keys(svi) - svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) - end - retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) - @test retval.m == svi[@varname(m)] # `m` is unconstrained - @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) + svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext())) + svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(VarNamedVector()), true) + svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext())) + + for svi in (svi_nt, svi_vnv) + # Sample with large variations in unconstrained space. + for i in 1:10 + for vn in keys(svi) + svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) + end + retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + @test retval.m == svi[@varname(m)] # `m` is unconstrained + @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` + + retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, retval.m, retval.x + ) - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.m, retval.x - ) + # Realizations from model should all be equal to the unconstrained realization. + for vn in DynamicPPL.TestUtils.varnames(model) + @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + end - # Realizations from model should all be equal to the unconstrained realization. - for vn in DynamicPPL.TestUtils.varnames(model) - @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 + # `getlogp` should be equal to the logjoint with log-absdet-jac correction. + lp = getlogp(svi) + @test lp ≈ lp_true end - - # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogp(svi) - @test lp ≈ lp_true end end @testset "Static transformation" begin model = DynamicPPL.TestUtils.demo_static_transformation() - priors = extract_priors(model) varinfos = DynamicPPL.TestUtils.setup_varinfos( model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)] diff --git a/test/test_util.jl b/test/test_util.jl index 021da2598..0c7949e48 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -85,12 +85,14 @@ Return string representing a short description of `vi`. short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = "threadsafe($(short_varinfo_name(vi.varinfo)))" function short_varinfo_name(vi::TypedVarInfo) - DynamicPPL.has_varnamevector(vi) && return "TypedVarInfo with VarNameVector" + DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" return "TypedVarInfo" end short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::VectorVarInfo) = "VectorVarInfo" short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" +short_varinfo_name(::SimpleVarInfo{<:VarNamedVector}) = "SimpleVarInfo{<:VarNamedVector}" # convenient functions for testing model.jl # function to modify the representation of values based on their length diff --git a/test/varinfo.jl b/test/varinfo.jl index 6750db416..a00935279 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -19,7 +19,7 @@ struct MySAlg end DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "varinfo.jl" begin - @testset "TypedVarInfo" begin + @testset "TypedVarInfo with Metadata" begin @model gdemo(x, y) = begin s ~ InverseGamma(2, 3) m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) @@ -28,7 +28,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end model = gdemo(1.0, 2.0) - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) model(vi, SampleFromUniform()) tvi = TypedVarInfo(vi) @@ -51,6 +51,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end end + @testset "Base" begin # Test Base functions: # string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!, @@ -120,6 +121,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_base!!(TypedVarInfo(vi)) test_base!!(SimpleVarInfo()) test_base!!(SimpleVarInfo(Dict())) + test_base!!(SimpleVarInfo(VarNamedVector())) end @testset "flags" begin # Test flag setting: @@ -141,12 +143,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) unset_flag!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end @testset "setgid!" begin - vi = VarInfo() + vi = VarInfo(DynamicPPL.Metadata()) meta = vi.metadata vn = @varname x dist = Normal(0, 1) @@ -196,8 +198,12 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) m_vns = model == model_uv ? [@varname(m[i]) for i in 1:5] : @varname(m) s_vns = @varname(s) - vi_typed = VarInfo(model) - vi_untyped = VarInfo() + # TODO(mhauru) Should add similar tests for VarNamedVector. These ones only apply + # to Metadata. + vi_typed = VarInfo( + model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata + ) + vi_untyped = VarInfo(DynamicPPL.Metadata()) model(vi_untyped, SampleFromPrior()) for vi in [vi_untyped, vi_typed] @@ -338,6 +344,14 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + + ## `SimpleVarInfo{<:VarNamedVector}` + vi = DynamicPPL.settrans!!(SimpleVarInfo(VarNamedVector()), true) + # Sample in unconstrained space. + vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) end @testset "values_as" begin @@ -409,8 +423,8 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) continue end - if DynamicPPL.has_varnamevector(varinfo) && mutating - # NOTE: Can't handle mutating `link!` and `invlink!` `VarNameVector`. + if DynamicPPL.has_varnamedvector(varinfo) && mutating + # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. @test_broken false continue end @@ -642,6 +656,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) varinfo_left = VarInfo(model_left) varinfo_right = VarInfo(model_right) + varinfo_right = DynamicPPL.settrans!!(varinfo_right, true, @varname(x)) varinfo_merged = merge(varinfo_left, varinfo_right) vns = [@varname(x), @varname(y), @varname(z)] @@ -649,13 +664,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # Right has precedence. @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] - @test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal + @test DynamicPPL.istrans(varinfo_merged, @varname(x)) end end @testset "VarInfo with selectors" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - varinfo = VarInfo(model) + varinfo = VarInfo( + model, + DynamicPPL.SampleFromPrior(), + DynamicPPL.DefaultContext(), + DynamicPPL.Metadata, + ) selector = DynamicPPL.Selector() spl = Sampler(MySAlg(), model, selector) diff --git a/test/varnamevector.jl b/test/varnamedvector.jl similarity index 89% rename from test/varnamevector.jl rename to test/varnamedvector.jl index 6b7acfbeb..604adc521 100644 --- a/test/varnamevector.jl +++ b/test/varnamedvector.jl @@ -7,7 +7,7 @@ decrease_size_for_test(x::Real) = x decrease_size_for_test(x::AbstractVector) = first(x) decrease_size_for_test(x::AbstractArray) = first(eachslice(x; dims=1)) -function need_varnames_relaxation(vnv::VarNameVector, vn::VarName, val) +function need_varnames_relaxation(vnv::VarNamedVector, vn::VarName, val) if isconcretetype(eltype(vnv.varnames)) # If the container is concrete, we need to make sure that the varname types match. # E.g. if `vnv.varnames` has `eltype` `VarName{:x, IndexLens{Tuple{Int64}}}` then @@ -20,22 +20,22 @@ function need_varnames_relaxation(vnv::VarNameVector, vn::VarName, val) return false end -function need_varnames_relaxation(vnv::VarNameVector, vns, vals) +function need_varnames_relaxation(vnv::VarNamedVector, vns, vals) return any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) end -function need_values_relaxation(vnv::VarNameVector, vn::VarName, val) +function need_values_relaxation(vnv::VarNamedVector, vn::VarName, val) if isconcretetype(eltype(vnv.vals)) return promote_type(eltype(vnv.vals), eltype(val)) != eltype(vnv.vals) end return false end -function need_values_relaxation(vnv::VarNameVector, vns, vals) +function need_values_relaxation(vnv::VarNamedVector, vns, vals) return any(need_values_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) end -function need_transforms_relaxation(vnv::VarNameVector, vn::VarName, val) +function need_transforms_relaxation(vnv::VarNamedVector, vn::VarName, val) return if isconcretetype(eltype(vnv.transforms)) # If the container is concrete, we need to make sure that the sizes match. # => If the sizes don't match, we need to relax the container type. @@ -50,13 +50,13 @@ function need_transforms_relaxation(vnv::VarNameVector, vn::VarName, val) false end end -function need_transforms_relaxation(vnv::VarNameVector, vns, vals) +function need_transforms_relaxation(vnv::VarNamedVector, vns, vals) return any(need_transforms_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) end """ - relax_container_types(vnv::VarNameVector, vn::VarName, val) - relax_container_types(vnv::VarNameVector, vns, val) + relax_container_types(vnv::VarNamedVector, vn::VarName, val) + relax_container_types(vnv::VarNamedVector, vns, val) Relax the container types of `vnv` if necessary to accommodate `vn` and `val`. @@ -74,10 +74,10 @@ Similarly: transformations type in `vnv`, then the underlying transformation type will be changed to `Any`. """ -function relax_container_types(vnv::VarNameVector, vn::VarName, val) +function relax_container_types(vnv::VarNamedVector, vn::VarName, val) return relax_container_types(vnv, [vn], [val]) end -function relax_container_types(vnv::VarNameVector, vns, vals) +function relax_container_types(vnv::VarNamedVector, vns, vals) if need_varnames_relaxation(vnv, vns, vals) varname_to_index_new = convert(OrderedDict{VarName,Int}, vnv.varname_to_index) varnames_new = convert(Vector{VarName}, vnv.varnames) @@ -98,31 +98,30 @@ function relax_container_types(vnv::VarNameVector, vns, vals) vnv.vals end - return VarNameVector( + return VarNamedVector( varname_to_index_new, varnames_new, vnv.ranges, vals_new, transforms_new, - vnv.is_transformed, + vnv.is_unconstrained, vnv.num_inactive, - vnv.metadata, ) end -@testset "VarNameVector" begin - # Need to test element-related operations: +@testset "VarNamedVector" begin + # Test element-related operations: # - `getindex` # - `setindex!` # - `push!` # - `update!` # - # And these should all be tested for different types of values: + # And these are all be tested for different types of values: # - scalar # - vector # - matrix - # Need to test operations on `VarNameVector`: + # Test operations on `VarNamedVector`: # - `empty!` # - `iterate` # - `convert` to @@ -143,12 +142,12 @@ end @testset "constructor: no args" begin # Empty. - vnv = VarNameVector() + vnv = VarNamedVector() @test isempty(vnv) @test eltype(vnv) == Real # Empty with types. - vnv = VarNameVector{VarName,Float64}() + vnv = VarNamedVector{VarName,Float64}() @test isempty(vnv) @test eltype(vnv) == Float64 end @@ -157,17 +156,17 @@ end @testset "$(vn_left) and $(vn_right)" for (vn_left, vn_right) in test_varnames_iter val_left = test_pairs[vn_left] val_right = test_pairs[vn_right] - vnv_base = VarNameVector([vn_left, vn_right], [val_left, val_right]) + vnv_base = VarNamedVector([vn_left, vn_right], [val_left, val_right]) # We'll need the transformations later. - # TODO: Should we test other transformations than just `FromVec`? + # TODO: Should we test other transformations than just `ReshapeTransform`? from_vec_left = DynamicPPL.from_vec_transform(val_left) from_vec_right = DynamicPPL.from_vec_transform(val_right) to_vec_left = inverse(from_vec_left) to_vec_right = inverse(from_vec_right) # Compare to alternative constructors. - vnv_from_dict = VarNameVector( + vnv_from_dict = VarNamedVector( OrderedDict(vn_left => val_left, vn_right => val_right) ) @test vnv_base == vnv_from_dict @@ -302,6 +301,7 @@ end end end end + @testset "update!" begin vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) @testset "$vn" for vn in test_vns @@ -375,7 +375,7 @@ end @testset "deterministic" begin n = 5 vn = @varname(x) - vnv = VarNameVector(OrderedDict(vn => [true])) + vnv = VarNamedVector(OrderedDict(vn => [true])) @test !DynamicPPL.has_inactive(vnv) # Growing should not create inactive ranges. for i in 1:n @@ -401,7 +401,7 @@ end @testset "random" begin n = 5 vn = @varname(x) - vnv = VarNameVector(OrderedDict(vn => [true])) + vnv = VarNamedVector(OrderedDict(vn => [true])) @test !DynamicPPL.has_inactive(vnv) # Insert a bunch of random-length vectors. @@ -421,9 +421,23 @@ end end end end + + @testset "subset" begin + vnv = VarNamedVector(test_pairs) + @test subset(vnv, test_vns) == vnv + @test subset(vnv, VarName[]) == VarNamedVector() + @test merge(subset(vnv, test_vns[1:3]), subset(vnv, test_vns[4:end])) == vnv + + # Test that subset preseres transformations and unconstrainedness. + vn = @varname(t[1]) + vns = vcat(test_vns, [vn]) + push!(vnv, vn, 2.0, x -> x^2) + vnv.is_unconstrained[vnv.varname_to_index[vn]] = true + @test subset(vnv, vns) == vnv + end end -@testset "VarInfo + VarNameVector" begin +@testset "VarInfo + VarNamedVector" begin models = DynamicPPL.TestUtils.DEMO_MODELS @testset "$(model.f)" for model in models # NOTE: Need to set random seed explicitly to avoid using the same seed @@ -435,8 +449,8 @@ end varinfos = DynamicPPL.TestUtils.setup_varinfos( model, value_true, varnames; include_threadsafe=false ) - # Filter out those which are not based on `VarNameVector`. - varinfos = filter(DynamicPPL.has_varnamevector, varinfos) + # Filter out those which are not based on `VarNamedVector`. + varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) # Get the true log joint. logp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...)