From 94bc38e6c4760b85e92b8c68c0e61d435ddd7f71 Mon Sep 17 00:00:00 2001
From: jaimerz
Date: Thu, 6 Jul 2023 11:33:24 +0100
Subject: [PATCH 01/28] first draft

---
 src/inference/Inference.jl | 51 +++++++++++++++++++++++++++++++++++++-
 src/inference/mh.jl        | 10 ++++++++
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index 4e1f35254..73ce55e72 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -68,7 +68,8 @@ export InferenceAlgorithm,
     resume,
     predict,
     isgibbscomponent,
-    externalsampler
+    externalsampler,
+    extract_priors
 
 #######################
 # Sampler abstraction #
@@ -485,6 +486,54 @@ end
 # Utilities #
 ##############
 
+Base.@kwdef struct PriorExtractorContext{D,Ctx} <: AbstractContext
+    priors::D=OrderedDict{VarName,Any}()
+    context::Ctx=SamplingContext()
+end
+
+NodeTrait(::PriorExtractorContext) = IsParent()
+childcontext(context::PriorExtractorContext) = context.context
+setchildcontext(parent::PriorExtractorContext, child) = PriorExtractorContext(parent.priors, child)
+
+function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi)
+    setprior!(context, vn, right)
+    return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
+end
+
+function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi)
+    setprior!(context, vn, right)
+    return DynamicPPL.dot_tilde_assume(childcontext(context), right, left, vn, vi)
+end
+
+function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)
+    context.priors[vn] = dist
+end
+
+function setprior!(context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution)
+    for vn in vns
+        context.priors[vn] = dist
+    end
+end
+
+function setprior!(context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dists::AbstractArray{<:Distribution})
+    # TODO: Support broadcasted expressions properly.
+    for (vn, dist) in zip(vns, dists)
+        context.priors[vn] = dist
+    end
+end
+
+"""
+    extract_priors(model::Model)
+
+Extract the priors from a model. This is done by sampling from the model and
+recording the distributions that are used to generate the samples.
+"""
+function extract_priors(model::Model)
+    context = PriorExtractorContext()
+    evaluate!!(model, VarInfo(), context)
+    return context.priors
+end
+
 DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg)
 DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))
 
diff --git a/src/inference/mh.jl b/src/inference/mh.jl
index ddbeaa2c5..66e5a84f7 100644
--- a/src/inference/mh.jl
+++ b/src/inference/mh.jl
@@ -188,6 +188,16 @@ function MH(space...)
     return MH{tuple(syms...), typeof(proposals)}(proposals)
 end
+function StaticMH(model::Model)
+    priors = extract_priors(model)
+    return AMH.MetropolisHastings(AMH.StaticProposal(priors))
+end
+
+function RWMH(model::Model)
+    priors = extract_priors(model)
+    return AMH.MetropolisHastings(AMH.RandomWalkProposal(priors))
+end
+
 #####################
 # Utility functions #
 #####################
 

From 26f5bace0ec1e7d7c4daf16eebe4480077b5440c Mon Sep 17 00:00:00 2001
From: jaimerz
Date: Thu, 6 Jul 2023 15:43:53 +0100
Subject: [PATCH 02/28] abstractcontext + tests

---
 src/inference/Inference.jl | 2 +-
 test/inference/mh.jl       | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index 73ce55e72..501e0de33 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -11,7 +11,7 @@ using DynamicPPL: Metadata, VarInfo, TypedVarInfo,
     Model, Sampler, SampleFromPrior, SampleFromUniform,
     DefaultContext, PriorContext,
     LikelihoodContext, set_flag!, unset_flag!,
-    getspace, inspace
+    getspace, inspace, AbstractContext
 using Distributions, Libtask, Bijectors
 using DistributionsAD: VectorOfMultivariate
 using LinearAlgebra
diff --git a/test/inference/mh.jl b/test/inference/mh.jl
index 8e52aec9b..38fe9fc18 100644
--- a/test/inference/mh.jl
+++ b/test/inference/mh.jl
@@ -17,6 +17,9 @@
 
         s4 = Gibbs(MH(:m), MH(:s))
         c4 = sample(gdemo_default, s4, N)
+
+        s5 = MH(gdemo_default)
+        c5 = sample(gdemo_default, s5, N)
     end
     @numerical_testset "mh inference" begin
         Random.seed!(125)

From b31908f211dc103efa35881d961acb91fd160b21 Mon Sep 17 00:00:00 2001
From: jaimerz
Date: Thu, 6 Jul 2023 16:56:04 +0100
Subject: [PATCH 03/28] bug

---
 test/inference/mh.jl | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/inference/mh.jl b/test/inference/mh.jl
index 38fe9fc18..2568fbcb8 100644
--- a/test/inference/mh.jl
+++ b/test/inference/mh.jl
@@ -18,8 +18,11 @@
         s4 = Gibbs(MH(:m), MH(:s))
         c4 = sample(gdemo_default, s4, N)
 
-        s5 = MH(gdemo_default)
+        s5 = RWMH(gdemo_default)
         c5 = sample(gdemo_default, s5, N)
+
+        s6 = StaticMH(gdemo_default)
+        c6 = sample(gdemo_default, s6, N)
     end
     @numerical_testset "mh inference" begin
         Random.seed!(125)

From af5018815806dd73c3febb9ed0e8932a920f27cb Mon Sep 17 00:00:00 2001
From: jaimerz
Date: Fri, 7 Jul 2023 09:43:55 +0100
Subject: [PATCH 04/28] externalsampler() in tests

---
 test/inference/mh.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/inference/mh.jl b/test/inference/mh.jl
index 2568fbcb8..eee86d4a0 100644
--- a/test/inference/mh.jl
+++ b/test/inference/mh.jl
@@ -18,10 +18,10 @@
         s4 = Gibbs(MH(:m), MH(:s))
         c4 = sample(gdemo_default, s4, N)
 
-        s5 = RWMH(gdemo_default)
+        s5 = externalsampler(RWMH(gdemo_default))
        c5 = sample(gdemo_default, s5, N)
 
-        s6 = StaticMH(gdemo_default)
+        s6 = externalsampler(StaticMH(gdemo_default))
        c6 = sample(gdemo_default, s6, N)
     end
     @numerical_testset "mh inference" begin

From 6725e4a887275201ed0edc93492316f8da4a2186 Mon Sep 17 00:00:00 2001
From: jaimerz
Date: Fri, 7 Jul 2023 17:37:49 +0100
Subject: [PATCH 05/28] NamedTuple problems

---
 src/Turing.jl              |  2 ++
 src/inference/Inference.jl |  7 ++++++-
 src/inference/mh.jl        | 12 +++++-------
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/Turing.jl b/src/Turing.jl
index 501f11290..91717489e 100644
--- a/src/Turing.jl
+++ b/src/Turing.jl
@@ -85,6 +85,7 @@ export @model, # modelling
 
        MH, # classic sampling
        RWMH,
+       StaticMH,
        Emcee,
        ESS,
        Gibbs,
@@ -113,6 +114,7 @@ export @model, # modelling
        @logprob_str,
        @prob_str,
        externalsampler,
+       extract_priors,
 
        setchunksize, # helper
        setadbackend,
diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index 501e0de33..96d973481 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -11,7 +11,9 @@ using DynamicPPL: Metadata, VarInfo, TypedVarInfo,
     Model, Sampler, SampleFromPrior, SampleFromUniform,
     DefaultContext, PriorContext,
     LikelihoodContext, set_flag!, unset_flag!,
-    getspace, inspace, AbstractContext
+    getspace, inspace, AbstractContext,
+    evaluate!!, IsParent, OrderedDict,
+    SamplingContext, Distribution
 using Distributions, Libtask, Bijectors
 using DistributionsAD: VectorOfMultivariate
 using LinearAlgebra
@@ -29,6 +31,7 @@ import AdvancedHMC; const AHMC = AdvancedHMC
 import AdvancedMH; const AMH = AdvancedMH
 import AdvancedPS
 import BangBang
+import DynamicPPL: tilde_assume, dot_tilde_assume, childcontext, setchildcontext, NodeTrait
 import ..Essential: getADbackend
 import EllipticalSliceSampling
 import LogDensityProblems
@@ -45,6 +48,8 @@ export InferenceAlgorithm,
     SampleFromUniform,
     SampleFromPrior,
     MH,
+    RWMH,
+    StaticMH,
     ESS,
     Emcee,
     Gibbs, # classic sampling
diff --git a/src/inference/mh.jl b/src/inference/mh.jl
index 66e5a84f7..74f8e6eed 100644
--- a/src/inference/mh.jl
+++ b/src/inference/mh.jl
@@ -188,14 +188,12 @@ function MH(space...)
     return MH{tuple(syms...), typeof(proposals)}(proposals)
 end
-function StaticMH(model::Model)
-    priors = extract_priors(model)
-    return AMH.MetropolisHastings(AMH.StaticProposal(priors))
-end
-
-function RWMH(model::Model)
-    priors = extract_priors(model)
-    return AMH.MetropolisHastings(AMH.RandomWalkProposal(priors))
+function MH(model::Model; proposal_type=AMH.StaticProposal)
+    priors = extract_priors(model)
+    props = Tuple([proposal_type(prop) for prop in values(priors)])
+    vars = Symbol.(keys(priors))
+    priors = NamedTuple{Tuple(vars)}(props)
+    return AMH.MetropolisHastings(priors)
 end
 #####################
 # Utility functions #
 #####################

From 427bef41fd0946b235570e1b7f9cbeab20e65fe6 Mon Sep 17 00:00:00 2001
From: jaimerz
Date: Tue, 11 Jul 2023 11:47:56 +0200
Subject: [PATCH 06/28] moving stuff to DynamicPPL PR

---
 src/inference/Inference.jl | 36 ------------------------------------
 1 file changed, 36 deletions(-)

diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index 96d973481..4a6dd2934 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -491,42 +491,6 @@ end
 # Utilities #
 ##############
 
-Base.@kwdef struct PriorExtractorContext{D,Ctx} <: AbstractContext
-    priors::D=OrderedDict{VarName,Any}()
-    context::Ctx=SamplingContext()
-end
-
-NodeTrait(::PriorExtractorContext) = IsParent()
-childcontext(context::PriorExtractorContext) = context.context
-setchildcontext(parent::PriorExtractorContext, child) = PriorExtractorContext(parent.priors, child)
-
-function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi)
-    setprior!(context, vn, right)
-    return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi)
-end
-
-function DynamicPPL.dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi)
-    setprior!(context, vn, right)
-    return DynamicPPL.dot_tilde_assume(childcontext(context), right, left, vn, vi)
-end
-
-function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)
-    context.priors[vn] = dist
-end
-
-function setprior!(context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution)
-    for vn in vns
-        context.priors[vn] = dist
-    end
-end
-
-function setprior!(context::PriorExtractorContext,
vns::AbstractArray{<:VarName}, dists::AbstractArray{<:Distribution}) - # TODO: Support broadcasted expressions properly. - for (vn, dist) in zip(vns, dists) - context.priors[vn] = dist - end -end - """ extract_priors(model::Model) From c0227652d65c4b27ea94e3234b8ffbc6c59eefa3 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 14 Jul 2023 11:06:03 +0200 Subject: [PATCH 07/28] using new DynamicPPL PR --- src/Turing.jl | 1 - src/inference/Inference.jl | 22 +++------------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index 91717489e..fc29cb4e9 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -114,7 +114,6 @@ export @model, # modelling @logprob_str, @prob_str, externalsampler, - extract_priors, setchunksize, # helper setadbackend, diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 4a6dd2934..a1e9f4bcb 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -11,9 +11,7 @@ using DynamicPPL: Metadata, VarInfo, TypedVarInfo, Model, Sampler, SampleFromPrior, SampleFromUniform, DefaultContext, PriorContext, LikelihoodContext, set_flag!, unset_flag!, - getspace, inspace, AbstractContext, - evaluate!!, IsParent, OrderedDict, - SamplingContext, Distribution + getspace, inspace using Distributions, Libtask, Bijectors using DistributionsAD: VectorOfMultivariate using LinearAlgebra @@ -31,8 +29,7 @@ import AdvancedHMC; const AHMC = AdvancedHMC import AdvancedMH; const AMH = AdvancedMH import AdvancedPS import BangBang -import DynamicPPL: tilde_assume, dot_tilde_assume, childcontext, setchildcontext, NodeTrait -import ..Essential: getADbackend +import DynamicPPL: extract_priors import EllipticalSliceSampling import LogDensityProblems import LogDensityProblemsAD @@ -73,8 +70,7 @@ export InferenceAlgorithm, resume, predict, isgibbscomponent, - externalsampler, - extract_priors + externalsampler ####################### # Sampler abstraction # @@ -491,18 +487,6 @@ end # Utilities # ############## -""" - extract_priors(model::Model) - -Extract the priors from a model. This is done by sampling from the model and -recording the distributions that are used to generate the samples. 
-""" -function extract_priors(model::Model) - context = PriorExtractorContext() - evaluate!!(model, VarInfo(), context) - return context.priors -end - DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg) DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg)) From 9ea35507c7c22fcc4513478e2627db0e68db47d6 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 14 Jul 2023 11:13:20 +0200 Subject: [PATCH 08/28] mistakenly removed line --- src/inference/Inference.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index a1e9f4bcb..02fb5a17b 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -29,6 +29,7 @@ import AdvancedHMC; const AHMC = AdvancedHMC import AdvancedMH; const AMH = AdvancedMH import AdvancedPS import BangBang +import ..Essential: getADbackend import DynamicPPL: extract_priors import EllipticalSliceSampling import LogDensityProblems From bdf51104542e57ff7587b56377ccb045f14a4ec3 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Fri, 14 Jul 2023 11:21:45 +0200 Subject: [PATCH 09/28] specific constructors --- src/inference/mh.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 74f8e6eed..9fd6c99f7 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -188,6 +188,14 @@ function MH(space...) return MH{tuple(syms...), typeof(proposals)}(proposals) end +function StaticMH(model::Model) + return MH(model; proposal_type=AMH.StaticProposal) +end + +function RWMH(model::Model) + return MH(model; proposal_type=AMH.RandomWalkProposal) +end + function MH(model::Model; proposal_type=AMH.StaticProposal) priors = extract_priors(model) props = Tuple([proposal_type(prop) for prop in values(priors)]) From 4ab5939f48225f57f13dc7184052e488e56112da Mon Sep 17 00:00:00 2001 From: jaimerz Date: Mon, 17 Jul 2023 10:55:03 +0100 Subject: [PATCH 10/28] no StaticMH RWMH --- src/inference/mh.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 9fd6c99f7..74f8e6eed 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -188,14 +188,6 @@ function MH(space...) 
return MH{tuple(syms...), typeof(proposals)}(proposals) end -function StaticMH(model::Model) - return MH(model; proposal_type=AMH.StaticProposal) -end - -function RWMH(model::Model) - return MH(model; proposal_type=AMH.RandomWalkProposal) -end - function MH(model::Model; proposal_type=AMH.StaticProposal) priors = extract_priors(model) props = Tuple([proposal_type(prop) for prop in values(priors)]) From 74869f4b825f9b99c75a2fc4bc7638acb4a4da25 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 27 Jul 2023 18:27:01 +0100 Subject: [PATCH 11/28] Bump bijectors compat (#2052) * CompatHelper: bump compat for Bijectors to 0.13, (keep existing compat) * Update Project.toml * Replacement for #2039 (#2040) * Fix testset for external samplers * Update abstractmcmc.jl * Update test/contrib/inference/abstractmcmc.jl Co-authored-by: Tor Erlend Fjelde * Update test/contrib/inference/abstractmcmc.jl Co-authored-by: Tor Erlend Fjelde * Update FillArrays compat to 1.4.1 (#2035) * Update FillArrays compat to 1.4.0 * Update test compat * Try to enable ReverseDiff tests * Update Project.toml * Update Project.toml * Bump version * Revert dependencies on FillArrays (#2042) * Update Project.toml * Update Project.toml * Fix redundant definition of `getstats` (#2044) * Fix redundant definition of `getstats` * Update Inference.jl * Revert "Update Inference.jl" This reverts commit e4f51c24fa7450d625d18b21ca3a273bb2d736d0. * Bump version --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * Transfer some test utility function into DynamicPPL (#2049) * Update OptimInterface.jl * Only run optimisation tests in numerical stage. * fix function lookup after moving functions --------- Co-authored-by: Xianda Sun * Move Optim support to extension (#2051) * Move Optim support to extension * More imports * Update Project.toml --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --------- Co-authored-by: CompatHelper Julia Co-authored-by: haris organtzidis Co-authored-by: Tor Erlend Fjelde Co-authored-by: David Widmann Co-authored-by: Xianda Sun Co-authored-by: Cameron Pfiffer --- Project.toml | 16 +++- .../TuringOptimExt.jl | 92 ++++++++++--------- src/Turing.jl | 41 +++++---- src/contrib/inference/abstractmcmc.jl | 1 - test/Project.toml | 4 +- test/contrib/inference/abstractmcmc.jl | 10 +- test/essential/ad.jl | 9 +- test/modes/OptimInterface.jl | 41 +-------- 8 files changed, 97 insertions(+), 117 deletions(-) rename src/modes/OptimInterface.jl => ext/TuringOptimExt.jl (69%) diff --git a/Project.toml b/Project.toml index a4cc4686a..f4fb14669 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.26.4" +version = "0.27" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -16,7 +16,6 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -44,20 +43,20 @@ AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" AdvancedVI = "0.2" BangBang = "0.3" -Bijectors = "0.12" +Bijectors = "0.13.2" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" 
DocStringExtensions = "0.8, 0.9" DynamicPPL = "0.23" EllipticalSliceSampling = "0.5, 1" -FillArrays = "=1.0.0" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" LogDensityProblems = "2" LogDensityProblemsAD = "1.4" MCMCChains = "5, 6" NamedArrays = "0.9" +Optim = "1" Reexport = "0.2, 1" Requires = "0.5, 1.0" SciMLBase = "1.37.1" @@ -68,3 +67,12 @@ StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" Tracker = "0.2.3" julia = "1.7" + +[weakdeps] +Optim = "429524aa-4258-5aef-a3af-852621145aeb" + +[extensions] +TuringOptimExt = "Optim" + +[extras] +Optim = "429524aa-4258-5aef-a3af-852621145aeb" diff --git a/src/modes/OptimInterface.jl b/ext/TuringOptimExt.jl similarity index 69% rename from src/modes/OptimInterface.jl rename to ext/TuringOptimExt.jl index f477fedb9..714cea202 100644 --- a/src/modes/OptimInterface.jl +++ b/ext/TuringOptimExt.jl @@ -1,14 +1,14 @@ -using Setfield -using DynamicPPL: DefaultContext, LikelihoodContext -using DynamicPPL: DynamicPPL -import .Optim -import .Optim: optimize -import ..ForwardDiff -import NamedArrays -import StatsBase -import Printf -import StatsAPI - +module TuringOptimExt + +if isdefined(Base, :get_extension) + import Turing + import Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase + import Optim +else + import ..Turing + import ..Turing: Distributions, DynamicPPL, ForwardDiff, NamedArrays, Printf, Setfield, Statistics, StatsAPI, StatsBase + import ..Optim +end """ ModeResult{ @@ -23,7 +23,7 @@ A wrapper struct to store various results from a MAP or MLE estimation. struct ModeResult{ V<:NamedArrays.NamedArray, O<:Optim.MultivariateOptimizationResults, - M<:OptimLogDensity + M<:Turing.OptimLogDensity } <: StatsBase.StatisticalModel "A vector with the resulting point estimates." values::V @@ -57,10 +57,10 @@ function StatsBase.coeftable(m::ModeResult; level::Real=0.95) estimates = m.values.array[:, 1] stderrors = StatsBase.stderror(m) zscore = estimates ./ stderrors - p = map(z -> StatsAPI.pvalue(Normal(), z; tail=:both), zscore) + p = map(z -> StatsAPI.pvalue(Distributions.Normal(), z; tail=:both), zscore) # Confidence interval (CI) - q = quantile(Normal(), (1 + level) / 2) + q = Statistics.quantile(Distributions.Normal(), (1 + level) / 2) ci_low = estimates .- q .* stderrors ci_high = estimates .+ q .* stderrors @@ -80,7 +80,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff # Hessian is computed with respect to the untransformed parameters. linked = DynamicPPL.istrans(m.f.varinfo) if linked - @set! m.f.varinfo = invlink!!(m.f.varinfo, m.f.model) + Setfield.@set! m.f.varinfo = DynamicPPL.invlink!!(m.f.varinfo, m.f.model) end # Calculate the Hessian. @@ -90,7 +90,7 @@ function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff # Link it back if we invlinked it. if linked - @set! m.f.varinfo = link!!(m.f.varinfo, m.f.model) + Setfield.@set! m.f.varinfo = DynamicPPL.link!!(m.f.varinfo, m.f.model) end return NamedArrays.NamedArray(info, (varnames, varnames)) @@ -126,18 +126,18 @@ mle = optimize(model, MLE()) mle = optimize(model, MLE(), NelderMead()) ``` """ -function Optim.optimize(model::Model, ::MLE, options::Optim.Options=Optim.Options(); kwargs...) +function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, options::Optim.Options=Optim.Options(); kwargs...) return _mle_optimize(model, options; kwargs...) end -function Optim.optimize(model::Model, ::MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...) 
+function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...) return _mle_optimize(model, init_vals, options; kwargs...) end -function Optim.optimize(model::Model, ::MLE, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...) +function Optim.optimize(model::DynamicPPL.Model, ::Turing.MLE, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...) return _mle_optimize(model, optimizer, options; kwargs...) end function Optim.optimize( - model::Model, - ::MLE, + model::DynamicPPL.Model, + ::Turing.MLE, init_vals::AbstractArray, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); @@ -146,9 +146,9 @@ function Optim.optimize( return _mle_optimize(model, init_vals, optimizer, options; kwargs...) end -function _mle_optimize(model::Model, args...; kwargs...) - ctx = OptimizationContext(DynamicPPL.LikelihoodContext()) - return _optimize(model, OptimLogDensity(model, ctx), args...; kwargs...) +function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...) + ctx = Turing.OptimizationContext(DynamicPPL.LikelihoodContext()) + return _optimize(model, Turing.OptimLogDensity(model, ctx), args...; kwargs...) end """ @@ -172,18 +172,18 @@ map_est = optimize(model, MAP(), NelderMead()) ``` """ -function Optim.optimize(model::Model, ::MAP, options::Optim.Options=Optim.Options(); kwargs...) +function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, options::Optim.Options=Optim.Options(); kwargs...) return _map_optimize(model, options; kwargs...) end -function Optim.optimize(model::Model, ::MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...) +function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...) return _map_optimize(model, init_vals, options; kwargs...) end -function Optim.optimize(model::Model, ::MAP, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...) +function Optim.optimize(model::DynamicPPL.Model, ::Turing.MAP, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); kwargs...) return _map_optimize(model, optimizer, options; kwargs...) end function Optim.optimize( - model::Model, - ::MAP, + model::DynamicPPL.Model, + ::Turing.MAP, init_vals::AbstractArray, optimizer::Optim.AbstractOptimizer, options::Optim.Options=Optim.Options(); @@ -192,9 +192,9 @@ function Optim.optimize( return _map_optimize(model, init_vals, optimizer, options; kwargs...) end -function _map_optimize(model::Model, args...; kwargs...) - ctx = OptimizationContext(DynamicPPL.DefaultContext()) - return _optimize(model, OptimLogDensity(model, ctx), args...; kwargs...) +function _map_optimize(model::DynamicPPL.Model, args...; kwargs...) + ctx = Turing.OptimizationContext(DynamicPPL.DefaultContext()) + return _optimize(model, Turing.OptimLogDensity(model, ctx), args...; kwargs...) end """ @@ -203,8 +203,8 @@ end Estimate a mode, i.e., compute a MLE or MAP estimate. """ function _optimize( - model::Model, - f::OptimLogDensity, + model::DynamicPPL.Model, + f::Turing.OptimLogDensity, optimizer::Optim.AbstractOptimizer=Optim.LBFGS(), args...; kwargs... @@ -213,8 +213,8 @@ function _optimize( end function _optimize( - model::Model, - f::OptimLogDensity, + model::DynamicPPL.Model, + f::Turing.OptimLogDensity, options::Optim.Options=Optim.Options(), args...; kwargs... 
@@ -223,8 +223,8 @@ function _optimize( end function _optimize( - model::Model, - f::OptimLogDensity, + model::DynamicPPL.Model, + f::Turing.OptimLogDensity, init_vals::AbstractArray=DynamicPPL.getparams(f), options::Optim.Options=Optim.Options(), args...; @@ -234,8 +234,8 @@ function _optimize( end function _optimize( - model::Model, - f::OptimLogDensity, + model::DynamicPPL.Model, + f::Turing.OptimLogDensity, init_vals::AbstractArray=DynamicPPL.getparams(f), optimizer::Optim.AbstractOptimizer=Optim.LBFGS(), options::Optim.Options=Optim.Options(), @@ -244,8 +244,8 @@ function _optimize( ) # Convert the initial values, since it is assumed that users provide them # in the constrained space. - @set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals) - @set! f.varinfo = DynamicPPL.link!!(f.varinfo, model) + Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, init_vals) + Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model) init_vals = DynamicPPL.getparams(f) # Optimize! @@ -258,10 +258,10 @@ function _optimize( # Get the VarInfo at the MLE/MAP point, and run the model to ensure # correct dimensionality. - @set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer) - @set! f.varinfo = invlink!!(f.varinfo, model) + Setfield.@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer) + Setfield.@set! f.varinfo = DynamicPPL.invlink!!(f.varinfo, model) vals = DynamicPPL.getparams(f) - @set! f.varinfo = link!!(f.varinfo, model) + Setfield.@set! f.varinfo = DynamicPPL.link!!(f.varinfo, model) # Make one transition to get the parameter names. ts = [Turing.Inference.Transition( @@ -275,3 +275,5 @@ function _optimize( return ModeResult(vmat, M, -M.minimum, f) end + +end # module diff --git a/src/Turing.jl b/src/Turing.jl index fc29cb4e9..b4ebef19d 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -11,6 +11,12 @@ import AdvancedVI using DynamicPPL: DynamicPPL, LogDensityFunction import DynamicPPL: getspace, NoDist, NamedDist import LogDensityProblems +import NamedArrays +import Setfield +import StatsAPI +import StatsBase + +import Printf import Random const PROGRESS = Ref(true) @@ -48,26 +54,9 @@ using .Inference include("variational/VariationalInference.jl") using .Variational -@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin - @eval Inference begin - import ..DynamicHMC - - if isdefined(DynamicHMC, :mcmc_with_warmup) - include("contrib/inference/dynamichmc.jl") - else - error("Please update DynamicHMC, v1.x is no longer supported") - end - end -end - include("modes/ModeEstimation.jl") using .ModeEstimation -@init @require Optim="429524aa-4258-5aef-a3af-852621145aeb" @eval begin - include("modes/OptimInterface.jl") - export optimize -end - ########### # Exports # ########### @@ -146,4 +135,22 @@ export @model, # modelling optim_objective, optim_function, optim_problem + +function __init__() + @static if !isdefined(Base, :get_extension) + @require Optim="429524aa-4258-5aef-a3af-852621145aeb" include("../ext/TuringOptimExt.jl") + end + @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin + @eval Inference begin + import ..DynamicHMC + + if isdefined(DynamicHMC, :mcmc_with_warmup) + include("contrib/inference/dynamichmc.jl") + else + error("Please update DynamicHMC, v1.x is no longer supported") + end + end + end +end + end diff --git a/src/contrib/inference/abstractmcmc.jl b/src/contrib/inference/abstractmcmc.jl index 19411ac99..0f158adfc 100644 --- a/src/contrib/inference/abstractmcmc.jl +++ b/src/contrib/inference/abstractmcmc.jl @@ -19,7 
+19,6 @@ getparams(transition::AdvancedHMC.Transition) = transition.z.θ getstats(transition::AdvancedHMC.Transition) = transition.stat getparams(transition::AdvancedMH.Transition) = transition.params -getstats(transition) = NamedTuple() getvarinfo(f::DynamicPPL.LogDensityFunction) = f.varinfo getvarinfo(f::LogDensityProblemsAD.ADGradientWrapper) = getvarinfo(parent(f)) diff --git a/test/Project.toml b/test/Project.toml index cec3c0d4b..470d50afe 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,7 +8,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -42,14 +41,13 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.23" -FillArrays = "=1.0.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" LogDensityProblems = "2" LogDensityProblemsAD = "1.4" MCMCChains = "5, 6" NamedArrays = "0.9.4" -Optim = "0.22, 1.0" +Optim = "1" Optimization = "3.5" OptimizationOptimJL = "0.1" PDMats = "0.10, 0.11" diff --git a/test/contrib/inference/abstractmcmc.jl b/test/contrib/inference/abstractmcmc.jl index 691130635..ca7e1b3ec 100644 --- a/test/contrib/inference/abstractmcmc.jl +++ b/test/contrib/inference/abstractmcmc.jl @@ -41,7 +41,7 @@ function initialize_mh(model) end @testset "External samplers" begin - @testset "AdvancedHMC.jl" begin + @turing_testset "AdvancedHMC.jl" begin for model in DynamicPPL.TestUtils.DEMO_MODELS # Need some functionality to initialize the sampler. # TODO: Remove this once the constructors in the respective packages become "lazy". @@ -52,12 +52,13 @@ end 5_000; nadapts=1_000, discard_initial=1_000, - rtol=0.2 + rtol=0.2, + sampler_name="AdvancedHMC" ) end end - @testset "AdvancedMH.jl" begin + @turing_testset "AdvancedMH.jl" begin for model in DynamicPPL.TestUtils.DEMO_MODELS # Need some functionality to initialize the sampler. # TODO: Remove this once the constructors in the respective packages become "lazy". 
@@ -68,7 +69,8 @@ end 10_000; discard_initial=1_000, thinning=10, - rtol=0.2 + rtol=0.2, + sampler_name="AdvancedMH" ) end end diff --git a/test/essential/ad.jl b/test/essential/ad.jl index 0a60a3e0f..8a0241a83 100644 --- a/test/essential/ad.jl +++ b/test/essential/ad.jl @@ -95,15 +95,14 @@ sample(dir(), HMC(0.01, 1), 1000) Turing.setrdcache(false) end - # FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff @testset "PDMatDistribution AD" begin @model function wishart() theta ~ Wishart(4, Matrix{Float64}(I, 4, 4)) end Turing.setadbackend(:tracker) sample(wishart(), HMC(0.01, 1), 1000); - #Turing.setadbackend(:reversediff) - #sample(wishart(), HMC(0.01, 1), 1000); + Turing.setadbackend(:reversediff) + sample(wishart(), HMC(0.01, 1), 1000); Turing.setadbackend(:zygote) sample(wishart(), HMC(0.01, 1), 1000); @@ -112,8 +111,8 @@ end Turing.setadbackend(:tracker) sample(invwishart(), HMC(0.01, 1), 1000); - #Turing.setadbackend(:reversediff) - #sample(invwishart(), HMC(0.01, 1), 1000); + Turing.setadbackend(:reversediff) + sample(invwishart(), HMC(0.01, 1), 1000); Turing.setadbackend(:zygote) sample(invwishart(), HMC(0.01, 1), 1000); end diff --git a/test/modes/OptimInterface.jl b/test/modes/OptimInterface.jl index ea873ffee..2418037a4 100644 --- a/test/modes/OptimInterface.jl +++ b/test/modes/OptimInterface.jl @@ -1,38 +1,3 @@ -# TODO: Remove these once the equivalent is present in `DynamicPPL.TestUtils. -function likelihood_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels) - return (s=1/16, m=7/4) -end -function posterior_optima(::DynamicPPL.TestUtils.UnivariateAssumeDemoModels) - # TODO: Figure out exact for `s`. - return (s=0.907407, m=7/6) -end - -function likelihood_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels) - # Get some containers to fill. - vals = Random.rand(model) - - # NOTE: These are "as close to zero as we can get". - vals.s[1] = 1e-32 - vals.s[2] = 1e-32 - - vals.m[1] = 1.5 - vals.m[2] = 2.0 - - return vals -end -function posterior_optima(model::DynamicPPL.TestUtils.MultivariateAssumeDemoModels) - # Get some containers to fill. - vals = Random.rand(model) - - # TODO: Figure out exact for `s[1]`. - vals.s[1] = 0.890625 - vals.s[2] = 1 - vals.m[1] = 3/4 - vals.m[2] = 1 - - return vals -end - # Used for testing how well it works with nested contexts. struct OverrideContext{C,T1,T2} <: DynamicPPL.AbstractContext context::C @@ -57,7 +22,7 @@ function DynamicPPL.tilde_observe(context::OverrideContext, right, left, vi) return context.loglikelihood_weight, vi end -@testset "OptimInterface.jl" begin +@numerical_testset "OptimInterface.jl" begin @testset "MLE" begin Random.seed!(222) true_value = [0.0625, 1.75] @@ -157,7 +122,7 @@ end # FIXME: Some models doesn't work for Tracker and ReverseDiff. if Turing.Essential.ADBACKEND[] === :forwarddiff @testset "MAP for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - result_true = posterior_optima(model) + result_true = DynamicPPL.TestUtils.posterior_optima(model) @testset "$(nameof(typeof(optimizer)))" for optimizer in [LBFGS(), NelderMead()] result = optimize(model, MAP(), optimizer) @@ -188,7 +153,7 @@ end DynamicPPL.TestUtils.demo_dot_assume_matrix_dot_observe_matrix, ] @testset "MLE for $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - result_true = likelihood_optima(model) + result_true = DynamicPPL.TestUtils.likelihood_optima(model) # `NelderMead` seems to struggle with convergence here, so we exclude it. 
@testset "$(nameof(typeof(optimizer)))" for optimizer in [LBFGS(),] From 017640704492c3a99e85ef968e4a33cb184a39cd Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sun, 30 Jul 2023 16:30:00 +0100 Subject: [PATCH 12/28] Bugfixes. --- src/contrib/inference/abstractmcmc.jl | 8 +++++++- src/inference/Inference.jl | 9 +++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/contrib/inference/abstractmcmc.jl b/src/contrib/inference/abstractmcmc.jl index 0f158adfc..8ae1719da 100644 --- a/src/contrib/inference/abstractmcmc.jl +++ b/src/contrib/inference/abstractmcmc.jl @@ -12,6 +12,12 @@ function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition) return Transition(varinfo, transition) end +function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) + set_namedtuple!(vi, θ) + vi +end +DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation) + # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: # https://github.com/TuringLang/AbstractMCMC.jl/pull/86 @@ -37,7 +43,7 @@ function AbstractMCMC.step( # Create a log-density function with an implementation of the # gradient so we ensure that we're using the same AD backend as in Turing. - f = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model)) + f = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model, SimpleVarInfo(model))) # Link the varinfo. f = setvarinfo(f, DynamicPPL.link!!(getvarinfo(f), model)) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 4ad730d16..a7c6bb965 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -102,6 +102,15 @@ Wrap a sampler so it can be used as an inference algorithm. """ externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler) +""" + ESLogDensityFunction + +A log density function for the External sampler. + +""" +const ESLogDensityFunction{M<:Model,S<:Sampler{<:ExternalSampler},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.DefaultContext} +LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple) = DynamicPPL.logjoint(f.model, SimpleVarInfo(x)) + # Algorithm for sampling from the prior struct Prior <: InferenceAlgorithm end From f775f523f8dcaeecfc0d70ea4d42b9d03a7f84c6 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sun, 30 Jul 2023 16:30:44 +0100 Subject: [PATCH 13/28] Add TODO. 
--- src/contrib/inference/abstractmcmc.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/contrib/inference/abstractmcmc.jl b/src/contrib/inference/abstractmcmc.jl index 8ae1719da..d1d81735f 100644 --- a/src/contrib/inference/abstractmcmc.jl +++ b/src/contrib/inference/abstractmcmc.jl @@ -12,6 +12,7 @@ function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition) return Transition(varinfo, transition) end +# TODO: move these functions to DynamicPPL function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) set_namedtuple!(vi, θ) vi From fb5612fa684766eefaec49e5e8bb700e94a179fd Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sun, 30 Jul 2023 16:35:20 +0100 Subject: [PATCH 14/28] Update mh.jl --- test/inference/mh.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inference/mh.jl b/test/inference/mh.jl index eee86d4a0..86fd374a2 100644 --- a/test/inference/mh.jl +++ b/test/inference/mh.jl @@ -18,10 +18,10 @@ s4 = Gibbs(MH(:m), MH(:s)) c4 = sample(gdemo_default, s4, N) - s5 = externalsampler(RWMH(gdemo_default)) + s5 = externalsampler(MH(gdemo_default, proposal_type=AMH.RandomWalkProposal)) c5 = sample(gdemo_default, s5, N) - s6 = externalsampler(StaticMH(gdemo_default)) + s6 = externalsampler(MH(gdemo_default, proposal_type=AMH.StaticProposal) c6 = sample(gdemo_default, s6, N) end @numerical_testset "mh inference" begin From 0abebd783438dbca61b26ddd2778b2ab44d4117b Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Sun, 30 Jul 2023 16:36:24 +0100 Subject: [PATCH 15/28] Update Inference.jl --- src/inference/Inference.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index a7c6bb965..05d472422 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -46,8 +46,6 @@ export InferenceAlgorithm, SampleFromUniform, SampleFromPrior, MH, - RWMH, - StaticMH, ESS, Emcee, Gibbs, # classic sampling From 3dae98133f8f2392aded9a35b4dff771674e150b Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sun, 30 Jul 2023 16:44:04 +0100 Subject: [PATCH 16/28] Removed obsolete exports. --- src/Turing.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index c0b2279c2..33286a665 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -73,8 +73,6 @@ export @model, # modelling Prior, # Sampling from the prior MH, # classic sampling - RWMH, - StaticMH, Emcee, ESS, Gibbs, From f792d7315b81380e5b8b2f7dd4b291e6e3faefdb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 10:16:24 +0100 Subject: [PATCH 17/28] removed unnecessary import of extract_priors --- src/inference/Inference.jl | 1 - src/inference/mh.jl | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 05d472422..e6d2e0596 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -30,7 +30,6 @@ import AdvancedMH; const AMH = AdvancedMH import AdvancedPS import BangBang import ..Essential: getADbackend -import DynamicPPL: extract_priors import EllipticalSliceSampling import LogDensityProblems import LogDensityProblemsAD diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 74f8e6eed..a32b8e6a9 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -189,7 +189,7 @@ function MH(space...) 
end function MH(model::Model; proposal_type=AMH.StaticProposal) - priors = extract_priors(model) + priors = DynamicPPL.extract_priors(model) props = Tuple([proposal_type(prop) for prop in values(priors)]) vars = Symbol.(keys(priors)) priors = NamedTuple{Tuple(vars)}(props) From b19abdc568b797903186611c7ea35f635503cf01 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 10:16:32 +0100 Subject: [PATCH 18/28] added missing ) in MH tests --- test/inference/mh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/mh.jl b/test/inference/mh.jl index 86fd374a2..ae8bcf79e 100644 --- a/test/inference/mh.jl +++ b/test/inference/mh.jl @@ -21,7 +21,7 @@ s5 = externalsampler(MH(gdemo_default, proposal_type=AMH.RandomWalkProposal)) c5 = sample(gdemo_default, s5, N) - s6 = externalsampler(MH(gdemo_default, proposal_type=AMH.StaticProposal) + s6 = externalsampler(MH(gdemo_default, proposal_type=AMH.StaticProposal)) c6 = sample(gdemo_default, s6, N) end @numerical_testset "mh inference" begin From 7f42b53cb01e0d79e698fe0dc899781bac104724 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 10:18:04 +0100 Subject: [PATCH 19/28] fixed incorrect referneces to AdvancedMH in tests --- test/inference/mh.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inference/mh.jl b/test/inference/mh.jl index ae8bcf79e..94f9aa992 100644 --- a/test/inference/mh.jl +++ b/test/inference/mh.jl @@ -18,10 +18,10 @@ s4 = Gibbs(MH(:m), MH(:s)) c4 = sample(gdemo_default, s4, N) - s5 = externalsampler(MH(gdemo_default, proposal_type=AMH.RandomWalkProposal)) + s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) c5 = sample(gdemo_default, s5, N) - s6 = externalsampler(MH(gdemo_default, proposal_type=AMH.StaticProposal)) + s6 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.StaticProposal)) c6 = sample(gdemo_default, s6, N) end @numerical_testset "mh inference" begin From 411d0301f7a00a16e532ba1757bb063283881684 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 13:32:16 +0100 Subject: [PATCH 20/28] improve ESLogDensityFunction --- src/inference/Inference.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index e6d2e0596..c6b652cb9 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -106,7 +106,9 @@ A log density function for the External sampler. 
""" const ESLogDensityFunction{M<:Model,S<:Sampler{<:ExternalSampler},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.DefaultContext} -LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple) = DynamicPPL.logjoint(f.model, SimpleVarInfo(x)) +function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple) + return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x)) +end # Algorithm for sampling from the prior struct Prior <: InferenceAlgorithm end From 7dd122a3f0be797c73d9d5571d7802e2657896a6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 13:32:36 +0100 Subject: [PATCH 21/28] remove hardcoding of SimpleVarInfo --- src/contrib/inference/abstractmcmc.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/contrib/inference/abstractmcmc.jl b/src/contrib/inference/abstractmcmc.jl index d1d81735f..926358f3a 100644 --- a/src/contrib/inference/abstractmcmc.jl +++ b/src/contrib/inference/abstractmcmc.jl @@ -14,8 +14,8 @@ end # TODO: move these functions to DynamicPPL function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) - set_namedtuple!(vi, θ) - vi + set_namedtuple!(deepcopy(vi), θ) + return vi end DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation) @@ -44,7 +44,7 @@ function AbstractMCMC.step( # Create a log-density function with an implementation of the # gradient so we ensure that we're using the same AD backend as in Turing. - f = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model, SimpleVarInfo(model))) + f = LogDensityProblemsAD.ADgradient(DynamicPPL.LogDensityFunction(model)) # Link the varinfo. f = setvarinfo(f, DynamicPPL.link!!(getvarinfo(f), model)) From 5f0963e4b3b6abe8b44ec31b56cf0f6c230df684 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 13:32:47 +0100 Subject: [PATCH 22/28] added fixme comment --- src/inference/mh.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index a32b8e6a9..43eb1db0e 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -354,6 +354,7 @@ end function should_link(varinfo, sampler, proposal::AdvancedMH.RandomWalkProposal) return true end +# FIXME: This won't be hit unless `vals` are all the exactly same concrete type of `AdvancedMH.RandomWalkProposal`! 
function should_link( varinfo, sampler, From 70ddc238d1b9410b783c4df9e9308d6257abadec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 13:33:04 +0100 Subject: [PATCH 23/28] minor style changes --- src/inference/mh.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 43eb1db0e..a9a34b5fe 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -191,8 +191,8 @@ end function MH(model::Model; proposal_type=AMH.StaticProposal) priors = DynamicPPL.extract_priors(model) props = Tuple([proposal_type(prop) for prop in values(priors)]) - vars = Symbol.(keys(priors)) - priors = NamedTuple{Tuple(vars)}(props) + vars = Tuple(map(Symbol, collect(keys(priors)))) + priors = NamedTuple{vars}(props) return AMH.MetropolisHastings(priors) end From 3fcb839471ba679a887b1b8212d3e7655a7eb6cb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 20:04:38 +0100 Subject: [PATCH 24/28] fixed issues with MH with RandomWalkProposal being used as an external sampler --- src/inference/mh.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index a9a34b5fe..12faa920f 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -188,11 +188,17 @@ function MH(space...) return MH{tuple(syms...), typeof(proposals)}(proposals) end -function MH(model::Model; proposal_type=AMH.StaticProposal) +# Some of the proposals require working in unconstrained space. +transform_maybe(proposal::AMH.Proposal) = proposal +function transform_maybe(proposal::AMH.RandomWalkProposal) + return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal)) +end + +function MH(model::Model; proposal_type=AMH.StatoicProposal) priors = DynamicPPL.extract_priors(model) props = Tuple([proposal_type(prop) for prop in values(priors)]) vars = Tuple(map(Symbol, collect(keys(priors)))) - priors = NamedTuple{vars}(props) + priors = map(transform_maybe, NamedTuple{vars}(props)) return AMH.MetropolisHastings(priors) end From 47584f841847869914752a9e6d6be3545d4ed611 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 20:05:59 +0100 Subject: [PATCH 25/28] fixed accidental typo --- src/inference/mh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 12faa920f..dd97efd18 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -194,7 +194,7 @@ function transform_maybe(proposal::AMH.RandomWalkProposal) return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal)) end -function MH(model::Model; proposal_type=AMH.StatoicProposal) +function MH(model::Model; proposal_type=AMH.StaticProposal) priors = DynamicPPL.extract_priors(model) props = Tuple([proposal_type(prop) for prop in values(priors)]) vars = Tuple(map(Symbol, collect(keys(priors)))) From bdc87b9201fe2dbf3ad2f3e8b8fbbc2d9935f62b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 20:27:02 +0100 Subject: [PATCH 26/28] move definitions of unflatten for NamedTuple --- src/contrib/inference/abstractmcmc.jl | 7 ------- src/inference/Inference.jl | 7 +++++++ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/contrib/inference/abstractmcmc.jl b/src/contrib/inference/abstractmcmc.jl index 926358f3a..0f158adfc 100644 --- a/src/contrib/inference/abstractmcmc.jl +++ b/src/contrib/inference/abstractmcmc.jl @@ -12,13 +12,6 @@ function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition) return 
Transition(varinfo, transition) end -# TODO: move these functions to DynamicPPL -function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) - set_namedtuple!(deepcopy(vi), θ) - return vi -end -DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation) - # NOTE: Only thing that depends on the underlying sampler. # Something similar should be part of AbstractMCMC at some point: # https://github.com/TuringLang/AbstractMCMC.jl/pull/86 diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index c6b652cb9..0eb1a18d1 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -110,6 +110,13 @@ function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple) return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x)) end +# TODO: move these functions to DynamicPPL +function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) + set_namedtuple!(deepcopy(vi), θ) + return vi +end +DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation) + # Algorithm for sampling from the prior struct Prior <: InferenceAlgorithm end From ef27088df8a46643950b8bc6dac76b1db6e8fe54 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 Jul 2023 20:27:32 +0100 Subject: [PATCH 27/28] improved TODO --- src/inference/Inference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 0eb1a18d1..45ae434a4 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -110,7 +110,7 @@ function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple) return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x)) end -# TODO: move these functions to DynamicPPL +# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL. function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) set_namedtuple!(deepcopy(vi), θ) return vi From 045e2d0c784658550ddbd8746d07181a9cdf29d7 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 16 Aug 2023 16:19:18 +0100 Subject: [PATCH 28/28] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c037ee8dc..e4a335fa4 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" AdvancedVI = "0.2" BangBang = "0.3" -Bijectors = "0.13.2" +Bijectors = "0.13.5" DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6"
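Taken together, this series lets `MH` build an `AdvancedMH.MetropolisHastings` sampler directly from a model, using `DynamicPPL.extract_priors` to pick one proposal per variable, with `externalsampler` used to run the result. Below is a minimal usage sketch based on the tests in the series; the toy model, chain lengths, and the `AdvancedMH` import are illustrative assumptions rather than part of the diff, and the snippet targets the state of this branch, not any released Turing version.

using Turing
import AdvancedMH

# A stand-in for `gdemo_default` from the Turing test suite (assumed model).
@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = gdemo([1.5, 2.0])

# Proposals are derived from the priors returned by `DynamicPPL.extract_priors`.
# Static (independence) proposals draw fresh values from each prior:
static_mh = MH(model; proposal_type=AdvancedMH.StaticProposal)
chain_static = sample(model, externalsampler(static_mh), 1_000)

# Random-walk proposals perturb the current value; per patch 24 they are moved
# to unconstrained space via `Bijectors.transformed` before use:
rw_mh = MH(model; proposal_type=AdvancedMH.RandomWalkProposal)
chain_rw = sample(model, externalsampler(rw_mh), 1_000)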