Skip to content

Commit

Permalink
subset and merge for VarInfo (clean version) (#544)
Browse files Browse the repository at this point in the history
* added `subset` which can extract a subset of the varinfo

* added testing of `subset` for `VarInfo`

* formatting

* added implementation of `merge` for `VarInfo` and tests for it

* more tests

* formatting

* improved merge_metadata for NamedTuple inputs

* added proper handling of the `vals` in `subset`

* added docs for `subset` and `merge`

* added `subset` and `merge` to documentation

* formatting

* made merge and subset part of the AbstractVarInfo interface

* added implementations `subset` and `merge` for `SimpleVarInfo`

* follow standard merge semantics where the right one takes precedence

* added proper testing of merge and subset for SimpleVarInfo too

* forgotten inclusion in previous commit

* Update src/simple_varinfo.jl

Co-authored-by: David Widmann <[email protected]>

* remove two-argument impl of merge

* formatting

* forgot to add more formatting

* removed 2-arg version of merge for abstract varinfo in favour of 3-arg version

* allow inclusion of threadsafe varinfo in setup_varinfos

* more tests for thread safe varinfo

* bugfixes for link and invlink methods when using thread safe varinfo

* attempt at fixing docs

* fixed missing test coverage

* formatting

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
torfjelde and devmotion authored Oct 19, 2023
1 parent 927799f commit efd9da3
Show file tree
Hide file tree
Showing 9 changed files with 716 additions and 12 deletions.
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ DynamicPPL.reconstruct
#### Utils

```@docs
Base.merge(::AbstractVarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export AbstractVarInfo,
SimpleVarInfo,
push!!,
empty!!,
subset,
getlogp,
setlogp!!,
acclogp!!,
Expand Down
161 changes: 161 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,27 @@ struct StaticTransformation{F} <: AbstractTransformation
bijector::F
end

"""
    merge_transformations(transformation_left, transformation_right)

Merge two transformations into a single transformation.

This is mainly used by [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref).
"""
merge_transformations(::NoTransformation, ::NoTransformation) = NoTransformation()
function merge_transformations(::DynamicTransformation, ::DynamicTransformation)
    # Both inputs are dynamic transformations, so the merged one is dynamic as well.
    return DynamicTransformation()
end
function merge_transformations(left::StaticTransformation, right::StaticTransformation)
    # Merge the wrapped bijectors first, then wrap the result again.
    merged_bijector = merge_bijectors(left.bijector, right.bijector)
    return StaticTransformation(merged_bijector)
end

# Merge two named transforms by merging their `bs` fields.
# NOTE(review): this assumes `bs` is a `NamedTuple` (as in Bijectors.jl), in which
# case `merge` lets the entries of `right` take precedence on a name clash — confirm.
function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform)
    # The previous code called `merge_bijector` (singular), a name with no visible
    # definition, so every call would fail with an `UndefVarError`.
    return Bijectors.NamedTransform(merge(left.bs, right.bs))
end

"""
default_transformation(model::Model[, vi::AbstractVarInfo])
Expand Down Expand Up @@ -337,6 +358,146 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP
return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)}))
end

# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
# might result in a `Vector{Any}`.
"""
subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName})
Subset a `varinfo` to only contain the variables `vns`.
!!! warning
The ordering of the variables in the resulting `varinfo` is _not_
guaranteed to follow the ordering of the variables in `varinfo`.
Hence care must be taken, in particular when used in conjunction with
other methods which use the vector representation of the `varinfo`,
e.g. `getindex(varinfo, sampler)`.
# Examples
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
julia> @model function demo()
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
x = Vector{Float64}(undef, 2)
x[1] ~ Normal(m, sqrt(s))
x[2] ~ Normal(m, sqrt(s))
end
demo (generic function with 2 methods)
julia> model = demo();
julia> varinfo = VarInfo(model);
julia> keys(varinfo)
4-element Vector{VarName}:
s
m
x[1]
x[2]
julia> for (i, vn) in enumerate(keys(varinfo))
varinfo[vn] = i
end
julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0
julia> # Extract one with only `m`.
varinfo_subset1 = subset(varinfo, [@varname(m),]);
julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
m
julia> varinfo_subset1[@varname(m)]
2.0
julia> # Extract one with both `s` and `x[2]`.
varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);
julia> keys(varinfo_subset2)
2-element Vector{VarName}:
s
x[2]
julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
2-element Vector{Float64}:
1.0
4.0
```
`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref)
```jldoctest varinfo-subset
julia> # Merge the two.
varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);
julia> keys(varinfo_subset_merged)
3-element Vector{VarName}:
m
s
x[2]
julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
3-element Vector{Float64}:
1.0
2.0
4.0
julia> # Merge the two with the original.
varinfo_merged = merge(varinfo, varinfo_subset_merged);
julia> keys(varinfo_merged)
4-element Vector{VarName}:
s
m
x[1]
x[2]
julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0
```
# Notes
## Type-stability
!!! warning
This function is only type-stable when `vns` contains only varnames
with the same symbol. For example, `[@varname(m[1]), @varname(m[2])]` will
be type-stable, but `[@varname(m[1]), @varname(x)]` will not be.
"""
function subset end  # no methods here; varinfo types such as `SimpleVarInfo` add them

"""
    merge(varinfo, other_varinfos...)

Combine several varinfos into one; when sensible, the right-most varinfo takes precedence.

This pairs naturally with [`subset(varinfo, vns)`](@ref).

Examples can be found in the docstring of [`subset(varinfo, vns)`](@ref).
"""
function Base.merge(varinfo::AbstractVarInfo)
    # With a single varinfo there is nothing to merge.
    return varinfo
end
# Only the case of three or more varinfos is defined generically, so that calling the
# 2-argument version errors for varinfo types that do not implement it.
function Base.merge(
    varinfo1::AbstractVarInfo,
    varinfo2::AbstractVarInfo,
    varinfo3::AbstractVarInfo,
    varinfo_others::AbstractVarInfo...,
)
    # Merge the two left-most varinfos first, then continue with the remaining ones.
    merged_head = Base.merge(varinfo1, varinfo2)
    return merge(merged_head, varinfo3, varinfo_others...)
end

# Transformations
"""
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:VarName}}])
Expand Down
45 changes: 45 additions & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,51 @@ function Base.eltype(
return V
end

# `subset`
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
    # Restrict the stored values to `vns`, then write them back into the `values` field.
    selected_values = _subset(varinfo.values, vns)
    return Setfield.@set varinfo.values = selected_values
end

function _subset(x::AbstractDict, vns)
    # NOTE: Every element of `vns` has to be an explicit key of `x`.
    all(vn -> haskey(x, vn), vns) || throw(
        ArgumentError(
            "Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " *
            "For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " *
            "`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.",
        ),
    )
    # Build the result through the constructor obtained from the type of `x`,
    # keeping only the requested entries.
    make_container = ConstructionBase.constructorof(typeof(x))
    return make_container(vn => x[vn] for vn in vns)
end

function _subset(x::NamedTuple, vns)
    # NOTE: Here we can only handle `vns` that contain the `IdentityLens`.
    # The predicate is the composition "take the lens of the varname, then check that
    # it is not the identity lens"; the `∘` is required for this to parse.
    if any(Base.Fix1(!==, Setfield.IdentityLens()) ∘ getlens, vns)
        throw(
            ArgumentError(
                "Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " *
                "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.",
            ),
        )
    end

    syms = map(getsym, vns)
    # `Fix1` fixes `x` as the first argument, i.e. each value is `getindex(x, sym)`.
    # (`Fix2` would evaluate `getindex(sym, x)`, which indexes into the symbol.)
    return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix1(getindex, x), syms)))
end

# `merge`
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
    # The value containers are combined with `merge`, with the right one passed last.
    merged_values = merge(varinfo_left.values, varinfo_right.values)
    # The log-probability is taken from the right varinfo only.
    logp_right = getlogp(varinfo_right)
    merged_transformation = merge_transformations(
        varinfo_left.transformation, varinfo_right.transformation
    )
    return SimpleVarInfo(merged_values, logp_right, merged_transformation)
end

# Context implementations
# NOTE: Evaluations, i.e. those without `rng` are shared with other
# implementations of `AbstractVarInfo`.
Expand Down
17 changes: 14 additions & 3 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,17 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal
end

"""
setup_varinfos(model::Model, example_values::NamedTuple, varnames)
setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false)
Return a tuple of instances for different implementations of `AbstractVarInfo` with
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
of the varinfo instances.
"""
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
function setup_varinfos(
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
)
# VarInfo
vi_untyped = VarInfo()
model(vi_untyped)
Expand All @@ -56,12 +61,18 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))

lp = getlogp(vi_typed)
return map((
varinfos = map((
vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref
)) do vi
# Set them all to the same values.
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end

if include_threadsafe
varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo deepcopy, varinfos)...)
end

return varinfos
end

"""
Expand Down
56 changes: 52 additions & 4 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,25 +84,56 @@ islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl
function link!!(
    t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    # Link the wrapped varinfo and store the result back in the wrapper, so the caller
    # gets a `ThreadSafeVarInfo` again instead of the bare inner varinfo.
    return Setfield.@set vi.varinfo = link!!(t, vi.varinfo, spl, model)
end

function invlink!!(
    t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    # Invlink the wrapped varinfo and store the result back in the wrapper, so the
    # caller gets a `ThreadSafeVarInfo` again instead of the bare inner varinfo.
    return Setfield.@set vi.varinfo = invlink!!(t, vi.varinfo, spl, model)
end

function link(
    t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    # Link the wrapped varinfo and return a wrapper around the result, so the caller
    # gets a `ThreadSafeVarInfo` again instead of the bare inner varinfo.
    return Setfield.@set vi.varinfo = link(t, vi.varinfo, spl, model)
end

function invlink(
    t::AbstractTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    # Invlink the wrapped varinfo and return a wrapper around the result, so the caller
    # gets a `ThreadSafeVarInfo` again instead of the bare inner varinfo.
    return Setfield.@set vi.varinfo = invlink(t, vi.varinfo, spl, model)
end

# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity.
# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure
# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates
# to define `getlogp(vi)`.
function link!!(
    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    # Evaluate the model on the thread-safe varinfo itself (not the wrapped one) under
    # the `{false}` dynamic-transformation context, then mark the result with `t`.
    evaluated = evaluate!!(model, vi, DynamicTransformationContext{false}())
    return settrans!!(last(evaluated), t)
end

function invlink!!(
    ::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    # Evaluate the model on the thread-safe varinfo itself (not the wrapped one) under
    # the `{true}` dynamic-transformation context, then mark the result as untransformed.
    evaluated = evaluate!!(model, vi, DynamicTransformationContext{true}())
    return settrans!!(last(evaluated), NoTransformation())
end

function link(
    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    # Hand a deep copy to the `!!` variant instead of `vi` itself.
    vi_copy = deepcopy(vi)
    return link!!(t, vi_copy, spl, model)
end

function invlink(
    t::DynamicTransformation, vi::ThreadSafeVarInfo, spl::AbstractSampler, model::Model
)
    # Hand a deep copy to the `!!` variant instead of `vi` itself.
    vi_copy = deepcopy(vi)
    return invlink!!(t, vi_copy, spl, model)
end

function maybe_invlink_before_eval!!(
Expand Down Expand Up @@ -192,3 +223,20 @@ istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
function istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
    # Forward the query to the wrapped varinfo.
    return istrans(vi.varinfo, vns)
end

function getval(vi::ThreadSafeVarInfo, vn::VarName)
    # Forward the lookup to the wrapped varinfo.
    return getval(vi.varinfo, vn)
end

function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
    # Unflatten into the wrapped varinfo and put the result back into the wrapper.
    inner = unflatten(vi.varinfo, x)
    return Setfield.@set vi.varinfo = inner
end
function unflatten(vi::ThreadSafeVarInfo, spl::AbstractSampler, x::AbstractVector)
    # Unflatten into the wrapped varinfo and put the result back into the wrapper.
    inner = unflatten(vi.varinfo, spl, x)
    return Setfield.@set vi.varinfo = inner
end

function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
    # Subset the wrapped varinfo and put the result back into the wrapper.
    inner_subset = subset(varinfo.varinfo, vns)
    return Setfield.@set varinfo.varinfo = inner_subset
end

function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo)
    # Only the wrapped varinfos are merged; the result is `varinfo_left` with its
    # `varinfo` field replaced by the merged one.
    merged_inner = merge(varinfo_left.varinfo, varinfo_right.varinfo)
    return Setfield.@set varinfo_left.varinfo = merged_inner
end
Loading

0 comments on commit efd9da3

Please sign in to comment.