Skip to content

Commit

Permalink
Merge branch 'master' into torfjelde/rand-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 24, 2023
2 parents de75531 + 12e7c27 commit 5de35af
Show file tree
Hide file tree
Showing 15 changed files with 863 additions and 40 deletions.
16 changes: 9 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.24.0"
version = "0.23.20"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -21,13 +22,20 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1.5.4"
Compat = "4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
LogDensityProblems = "2"
Expand All @@ -39,11 +47,5 @@ Setfield = "0.7.1, 0.8, 1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
[compat]
DataStructures = "0.18"
Distributions = "0.25"
Documenter = "0.27"
Documenter = "1"
FillArrays = "0.13, 1"
LogDensityProblems = "2"
MCMCChains = "5, 6"
Expand Down
1 change: 0 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ makedocs(;
"API" => "api.md",
"Tutorials" => ["tutorials/prob-interface.md"],
],
strict=true,
checkdocs=:exports,
)

Expand Down
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ DynamicPPL.reconstruct
#### Utils

```@docs
Base.merge(::AbstractVarInfo)
DynamicPPL.subset
DynamicPPL.unflatten
DynamicPPL.tonamedtuple
DynamicPPL.varname_leaves
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module DynamicPPL
using AbstractMCMC: AbstractSampler, AbstractChains
using AbstractPPL
using Bijectors
using Compat
using Distributions
using OrderedCollections: OrderedDict

Expand Down Expand Up @@ -47,6 +48,7 @@ export AbstractVarInfo,
SimpleVarInfo,
push!!,
empty!!,
subset,
getlogp,
setlogp!!,
acclogp!!,
Expand Down
161 changes: 161 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,27 @@ struct StaticTransformation{F} <: AbstractTransformation
bijector::F
end

"""
merge_transformations(transformation_left, transformation_right)
Merge two transformations.
The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref).
"""
function merge_transformations(::NoTransformation, ::NoTransformation)
return NoTransformation()
end
function merge_transformations(::DynamicTransformation, ::DynamicTransformation)
return DynamicTransformation()
end
function merge_transformations(left::StaticTransformation, right::StaticTransformation)
return StaticTransformation(merge_bijectors(left.bijector, right.bijector))
end

function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform)
return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs))
end

"""
default_transformation(model::Model[, vi::AbstractVarInfo])
Expand Down Expand Up @@ -337,6 +358,146 @@ function Base.eltype(vi::AbstractVarInfo, spl::Union{AbstractSampler,SampleFromP
return eltype(Core.Compiler.return_type(getindex, Tuple{typeof(vi),typeof(spl)}))
end

# TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert
# the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which
# might result in a `Vector{Any}`.
"""
subset(varinfo::AbstractVarInfo, vns::AbstractVector{<:VarName})
Subset a `varinfo` to only contain the variables `vns`.
!!! warning
The ordering of the variables in the resulting `varinfo` is _not_
guaranteed to follow the ordering of the variables in `varinfo`.
Hence care must be taken, in particular when used in conjunction with
other methods which uses the vector-representation of the `varinfo`,
e.g. `getindex(varinfo, sampler)`.
# Examples
```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL)
julia> @model function demo()
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
x = Vector{Float64}(undef, 2)
x[1] ~ Normal(m, sqrt(s))
x[2] ~ Normal(m, sqrt(s))
end
demo (generic function with 2 methods)
julia> model = demo();
julia> varinfo = VarInfo(model);
julia> keys(varinfo)
4-element Vector{VarName}:
s
m
x[1]
x[2]
julia> for (i, vn) in enumerate(keys(varinfo))
varinfo[vn] = i
end
julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0
julia> # Extract one with only `m`.
varinfo_subset1 = subset(varinfo, [@varname(m),]);
julia> keys(varinfo_subset1)
1-element Vector{VarName{:m, Setfield.IdentityLens}}:
m
julia> varinfo_subset1[@varname(m)]
2.0
julia> # Extract one with both `s` and `x[2]`.
varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]);
julia> keys(varinfo_subset2)
2-element Vector{VarName}:
s
x[2]
julia> varinfo_subset2[[@varname(s), @varname(x[2])]]
2-element Vector{Float64}:
1.0
4.0
```
`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref)
```jldoctest varinfo-subset
julia> # Merge the two.
varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2);
julia> keys(varinfo_subset_merged)
3-element Vector{VarName}:
m
s
x[2]
julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]]
3-element Vector{Float64}:
1.0
2.0
4.0
julia> # Merge the two with the original.
varinfo_merged = merge(varinfo, varinfo_subset_merged);
julia> keys(varinfo_merged)
4-element Vector{VarName}:
s
m
x[1]
x[2]
julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]]
4-element Vector{Float64}:
1.0
2.0
3.0
4.0
```
# Notes
## Type-stability
!!! warning
This function is only type-stable when `vns` contains only varnames
with the same symbol. For exmaple, `[@varname(m[1]), @varname(m[2])]` will
be type-stable, but `[@varname(m[1]), @varname(x)]` will not be.
"""
function subset end

"""
merge(varinfo, other_varinfos...)
Merge varinfos into one, giving precedence to the right-most varinfo when sensible.
This is particularly useful when combined with [`subset(varinfo, vns)`](@ref).
See docstring of [`subset(varinfo, vns)`](@ref) for examples.
"""
Base.merge(varinfo::AbstractVarInfo) = varinfo
# Define 3-argument version so 2-argument version will error if not implemented.
function Base.merge(
varinfo1::AbstractVarInfo,
varinfo2::AbstractVarInfo,
varinfo3::AbstractVarInfo,
varinfo_others::AbstractVarInfo...,
)
return merge(Base.merge(varinfo1, varinfo2), varinfo3, varinfo_others...)
end

# Transformations
"""
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
Expand Down
45 changes: 45 additions & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,51 @@ function Base.eltype(
return V
end

# `subset`
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
return Setfield.@set varinfo.values = _subset(varinfo.values, vns)
end

function _subset(x::AbstractDict, vns)
# NOTE: This requires `vns` to be explicitly present in `x`.
if any(!Base.Fix1(haskey, x), vns)
throw(
ArgumentError(
"Cannot subset `AbstractDict` with `VarName` that is not an explicit key. " *
"For example, if `keys(x) == [@varname(x[1])]`, then subsetting with " *
"`@varname(x[1])` is allowed, but subsetting with `@varname(x)` is not.",
),
)
end
C = ConstructionBase.constructorof(typeof(x))
return C(vn => x[vn] for vn in vns)
end

function _subset(x::NamedTuple, vns)
# NOTE: Here we can only handle `vns` that contain the `IdentityLens`.
if any(Base.Fix1(!==, Setfield.IdentityLens()) getlens, vns)
throw(
ArgumentError(
"Cannot subset `NamedTuple` with non-`IdentityLens` `VarName`. " *
"For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.",
),
)
end

syms = map(getsym, vns)
return NamedTuple{Tuple(syms)}(Tuple(map(Base.Fix2(getindex, x), syms)))
end

# `merge`
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
values = merge(varinfo_left.values, varinfo_right.values)
logp = getlogp(varinfo_right)
transformation = merge_transformations(
varinfo_left.transformation, varinfo_right.transformation
)
return SimpleVarInfo(values, logp, transformation)
end

# Context implementations
# NOTE: Evaluations, i.e. those without `rng` are shared with other
# implementations of `AbstractVarInfo`.
Expand Down
17 changes: 14 additions & 3 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,17 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; isequal=isequal
end

"""
setup_varinfos(model::Model, example_values::NamedTuple, varnames)
setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false)
Return a tuple of instances for different implementations of `AbstractVarInfo` with
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
of the varinfo instances.
"""
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
function setup_varinfos(
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
)
# VarInfo
vi_untyped = VarInfo()
model(vi_untyped)
Expand All @@ -56,12 +61,18 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))

lp = getlogp(vi_typed)
return map((
varinfos = map((
vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref
)) do vi
# Set them all to the same values.
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end

if include_threadsafe
varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo deepcopy, varinfos)...)
end

return varinfos
end

"""
Expand Down
Loading

0 comments on commit 5de35af

Please sign in to comment.