Bump DynamicPPL to v0.25 (#2197)
* replaced Setfield with Accessors to bump up to DPPL v0.25

* bump DPPL version

* use Accessors

* replaced usages of `@set!` with `BangBang.@set!!`

* fixed Project.toml

* reverted accidental change

* import BangBang in Turing

* replace `BangBang.@set!!` with `Accessors.@set`

* bump minor version since this is a breaking change

* make failing test conditional on Julia version >1.7

* fixed references to Setfield.jl in Experimental module

* disabled another test due to the same issue

---------

Co-authored-by: Xianda Sun <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
3 people authored Apr 23, 2024
1 parent a022dc6 commit 9be6b79
Showing 11 changed files with 60 additions and 50 deletions.
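The recurring change in the diffs below is mechanical: Setfield.@set! updated a field and rebound the variable in the caller's scope, while Accessors.@set only returns an updated copy, so the result has to be rebound explicitly (hence the new f = Accessors.@set f.varinfo = ... form). A minimal sketch of the pattern, using a hypothetical immutable struct rather than a Turing.jl type:

using Accessors

struct Wrapper               # hypothetical stand-in for something like LogDensityFunction
    varinfo::Vector{Float64}
end

f = Wrapper([1.0, 2.0])

# Old pattern (Setfield.jl): `Setfield.@set! f.varinfo = ...` rebound `f` implicitly.
# New pattern (Accessors.jl): `@set` returns the updated copy; rebind it yourself.
f = Accessors.@set f.varinfo = [3.0, 4.0]

f.varinfo == [3.0, 4.0]      # true; the update is functional, nothing is mutated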
10 changes: 5 additions & 5 deletions Project.toml
@@ -1,10 +1,11 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.30.9"
version = "0.31.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
@@ -29,7 +30,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
@@ -47,18 +47,19 @@ TuringOptimExt = "Optim"
[compat]
ADTypes = "0.2"
AbstractMCMC = "5.2"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.5.4"
AdvancedVI = "0.2"
BangBang = "0.3"
BangBang = "0.4"
Bijectors = "0.13.6"
DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.24.10"
DynamicPPL = "0.25.1"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.7, 0.8"
@@ -70,7 +71,6 @@ Optim = "1"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1.37.1, 2"
Setfield = "0.8, 1"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
Statistics = "1.6"
StatsAPI = "1.6"
18 changes: 9 additions & 9 deletions ext/TuringOptimExt.jl
@@ -2,11 +2,11 @@ module TuringOptimExt

if isdefined(Base, :get_extension)
import Turing
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import Optim
else
import ..Turing
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase
import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Accessors, Statistics, StatsAPI, StatsBase
import ..Optim
end

@@ -80,7 +80,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff
# Hessian is computed with respect to the untransformed parameters.
linked = DynamicPPL.istrans(m.f.varinfo)
if linked
Setfield.@set! m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
m = Accessors.@set m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model)
end

# Calculate the Hessian, which is the information matrix because the negative of the log likelihood was optimized
@@ -89,7 +89,7 @@

# Link it back if we invlinked it.
if linked
Setfield.@set! m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
m = Accessors.@set m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model)
end

return NamedArrays.NamedArray(info, (varnames, varnames))
@@ -227,8 +227,8 @@ function _optimize(
)
# Convert the initial values, since it is assumed that users provide them
# in the constrained space.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)
init_vals = DynamicPPL.getparams(f)

# Optimize!
@@ -241,10 +241,10 @@

# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
Setfield.@set! f.varinfo = DynamicPPL.invlink(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
f = Accessors.@set f.varinfo = DynamicPPL.invlink(f.varinfo, model)
vals = DynamicPPL.getparams(f)
Setfield.@set! f.varinfo = DynamicPPL.link(f.varinfo, model)
f = Accessors.@set f.varinfo = DynamicPPL.link(f.varinfo, model)

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(
4 changes: 3 additions & 1 deletion src/Turing.jl
@@ -11,10 +11,12 @@ using DynamicPPL: DynamicPPL, LogDensityFunction
import DynamicPPL: getspace, NoDist, NamedDist
import LogDensityProblems
import NamedArrays
import Setfield
import Accessors
import StatsAPI
import StatsBase

using Accessors: Accessors

import Printf
import Random

2 changes: 1 addition & 1 deletion src/experimental/Experimental.jl
@@ -3,7 +3,7 @@ module Experimental
using Random: Random
using AbstractMCMC: AbstractMCMC
using DynamicPPL: DynamicPPL, VarName
using Setfield: Setfield
using Accessors: Accessors

using DocStringExtensions: TYPEDFIELDS
using Distributions
10 changes: 5 additions & 5 deletions src/experimental/gibbs.jl
@@ -107,9 +107,9 @@ Returns the preferred value type for a variable with the given `varinfo`.
preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict
preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple
function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo)
# We can only do this in the scenario where all the varnames are `Setfield.IdentityLens`.
# We can only do this in the scenario where all the varnames are `Accessors.IdentityLens`.
namedtuple_compatible = all(varinfo.metadata) do md
eltype(md.vns) <: VarName{<:Any,Setfield.IdentityLens}
eltype(md.vns) <: VarName{<:Any,typeof(identity)}
end
return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict
end
@@ -321,8 +321,8 @@ function AbstractMCMC.step(
)

# Update the `states` and `varinfos`.
states = Setfield.setindex(states, new_state_local, index)
varinfos = Setfield.setindex(varinfos, new_varinfo_local, index)
states = Accessors.setindex(states, new_state_local, index)
varinfos = Accessors.setindex(varinfos, new_varinfo_local, index)
end

# Combine the resulting varinfo objects.
@@ -349,7 +349,7 @@ function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler
# NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide
# a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact
# same `selector` as before but now with `rerun` set to `true` if needed.
return Setfield.@set sampler.selector.rerun = true
return Accessors.@set sampler.selector.rerun = true
end

# Interface we need a sampler to implement to work as a component in a Gibbs sampler.
4 changes: 2 additions & 2 deletions src/mcmc/Inference.jl
@@ -21,14 +21,14 @@ using DynamicPPL
using AbstractMCMC: AbstractModel, AbstractSampler
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using DataStructures: OrderedSet
using Setfield: Setfield
using Accessors: Accessors

import ADTypes
import AbstractMCMC
import AdvancedHMC; const AHMC = AdvancedHMC
import AdvancedMH; const AMH = AdvancedMH
import AdvancedPS
import BangBang
import Accessors
import EllipticalSliceSampling
import LogDensityProblems
import LogDensityProblemsAD
2 changes: 1 addition & 1 deletion src/mcmc/abstractmcmc.jl
@@ -23,7 +23,7 @@ getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.pa
getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo
getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f))

setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Setfield.@set f.varinfo = varinfo
setvarinfo(f::DynamicPPL.LogDensityFunction, varinfo) = Accessors.@set f.varinfo = varinfo
setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo) = setvarinfo(parent(f), varinfo)

# TODO: Do we also support `resume`, etc?
4 changes: 2 additions & 2 deletions src/optimisation/Optimisation.jl
@@ -5,7 +5,7 @@ using Bijectors
using Random
using SciMLBase: OptimizationFunction, OptimizationProblem, AbstractADType, NoAD

using Setfield
using Accessors: Accessors
using DynamicPPL
using DynamicPPL: Model, AbstractContext, VarInfo, VarName,
_getindex, getsym, getfield, setorder!,
@@ -150,7 +150,7 @@ function transform!!(f::OptimLogDensity)
linked = DynamicPPL.istrans(f.varinfo)

## transform into constrained or unconstrained space depending on current state of vi
@set! f.varinfo = if !linked
f = Accessors.@set f.varinfo = if !linked
DynamicPPL.link!!(f.varinfo, f.model)
else
DynamicPPL.invlink!!(f.varinfo, f.model)
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -40,7 +40,7 @@ Clustering = "0.14, 0.15"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.24"
DynamicPPL = "0.25.1"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
LogDensityProblems = "2"
20 changes: 12 additions & 8 deletions test/mcmc/hmc.jl
@@ -247,14 +247,18 @@
end
end

@turing_testset "(partially) issue: #2095" begin
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
xs[2] ~ Dirichlet(ones(5))
# Disable on Julia <1.8 due to https://github.com/TuringLang/Turing.jl/pull/2197.
# TODO: Remove this block once https://github.com/JuliaFolds2/BangBang.jl/pull/22 has been released.
if VERSION ≥ v"1.8"
@turing_testset "(partially) issue: #2095" begin
@model function vector_of_dirichlet(::Type{TV}=Vector{Float64}) where {TV}
xs = Vector{TV}(undef, 2)
xs[1] ~ Dirichlet(ones(5))
xs[2] ~ Dirichlet(ones(5))
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
@test mean(Array(chain)) ≈ 0.2
end
model = vector_of_dirichlet()
chain = sample(model, NUTS(), 1000)
@test mean(Array(chain)) ≈ 0.2
end
end
34 changes: 19 additions & 15 deletions test/mcmc/mh.jl
@@ -162,24 +162,28 @@
# @test v1 < v2
end

@turing_testset "vector of multivariate distributions" begin
@model function test(k)
T = Vector{Vector{Float64}}(undef, k)
for i in 1:k
T[i] ~ Dirichlet(5, 1.0)
# Disable on Julia <1.8 due to https://github.com/TuringLang/Turing.jl/pull/2197.
# TODO: Remove this block once https://github.com/JuliaFolds2/BangBang.jl/pull/22 has been released.
if VERSION ≥ v"1.8"
@turing_testset "vector of multivariate distributions" begin
@model function test(k)
T = Vector{Vector{Float64}}(undef, k)
for i in 1:k
T[i] ~ Dirichlet(5, 1.0)
end
end
end

Random.seed!(100)
chain = sample(test(1), MH(), 5_000)
for i in 1:5
@test mean(chain, "T[1][$i]") ≈ 0.2 atol=0.01
end
Random.seed!(100)
chain = sample(test(1), MH(), 5_000)
for i in 1:5
@test mean(chain, "T[1][$i]") ≈ 0.2 atol = 0.01
end

Random.seed!(100)
chain = sample(test(10), MH(), 5_000)
for j in 1:10, i in 1:5
@test mean(chain, "T[$j][$i]") ≈ 0.2 atol=0.01
Random.seed!(100)
chain = sample(test(10), MH(), 5_000)
for j in 1:10, i in 1:5
@test mean(chain, "T[$j][$i]") ≈ 0.2 atol = 0.01
end
end
end


2 comments on commit 9be6b79

@torfjelde
Member Author


@JuliaRegistrator register

Breaking

  • DynamicPPL.jl, the package defining the @model macro, etc., is now using Accessors.jl instead of Setfield.jl.
  • Indexable variables will now show up slightly differently in the chains, e.g. x[:,1] will become x[:, 1] (see the sketch below).
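For downstream code that hard-codes chain column names such as "x[:,1]", a hedged sketch of a formatting-agnostic alternative is to look the names up from the chain instead. This assumes MCMCChains (group and namesingroup are its functions); the toy Chains object and its column names below are illustrative only and do not come from this commit:

using MCMCChains

# Toy chain whose two columns use the new DynamicPPL 0.25 name formatting.
chain = Chains(rand(100, 2, 1), [Symbol("x[:, 1]"), Symbol("x[:, 2]")])

namesingroup(chain, :x)   # both column names, however the indices are printed
group(chain, :x)          # sub-chain containing only the x[...] columns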

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/105466

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.31.0 -m "<description of version>" 9be6b792112c8fff36051c5732aa95be2b43ac29
git push origin v0.31.0
