From 5121cdfed67e7afa1d9ff410fa2cda3d8b9b8e78 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 17 Jun 2024 09:10:33 +0100 Subject: [PATCH 1/6] lazily resolve context to avoid overriding the model context --- src/logdensityfunction.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 007dfef11..d2d272bbe 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -49,7 +49,7 @@ struct LogDensityFunction{V,M,C} varinfo::V "model used for evaluation" model::M - "context used for evaluation" + "context used for evaluation; if `nothing`, `model.context` will be used when applicable" context::C end @@ -66,11 +66,14 @@ end function LogDensityFunction( model::Model, varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=model.context, + context::Union{Nothing,AbstractContext}=nothing, ) return LogDensityFunction(varinfo, model, context) end +# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`. +getcontext(f::LogDensityFunction) = f.context === nothing ? leafcontext(f.model) : f.context + # HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time # we need to define these annoying methods to ensure that we stay compatible with everything. getsampler(f::LogDensityFunction) = getsampler(f.context) @@ -90,8 +93,9 @@ getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(f.context)] # LogDensityProblems interface function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) - vi_new = unflatten(f.varinfo, f.context, θ) - return getlogp(last(evaluate!!(f.model, vi_new, f.context))) + context = getcontext(f) + vi_new = unflatten(f.varinfo, context, θ) + return getlogp(last(evaluate!!(f.model, vi_new, context))) end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) return LogDensityProblems.LogDensityOrder{0}() From a353b4cd0c175a16dc5c43c1eaeb1dcfad75dda3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 17 Jun 2024 09:11:26 +0100 Subject: [PATCH 2/6] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a81112d74..6b131ef87 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.27.2" +version = "0.27.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From d1f5a7516d9a16c45027c49ca80734840625dc9f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 Jun 2024 08:02:52 +0100 Subject: [PATCH 3/6] Update src/logdensityfunction.jl --- src/logdensityfunction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index d2d272bbe..7db7bed3c 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -72,7 +72,7 @@ function LogDensityFunction( end # If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`. -getcontext(f::LogDensityFunction) = f.context === nothing ? leafcontext(f.model) : f.context +getcontext(f::LogDensityFunction) = f.context === nothing ? leafcontext(f.model.context) : f.context # HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time # we need to define these annoying methods to ensure that we stay compatible with everything. From 92b2102bd33067072037813c15ef0feffc08a1fd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 Jun 2024 08:04:11 +0100 Subject: [PATCH 4/6] Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/logdensityfunction.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 7db7bed3c..55a4a596a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -72,7 +72,9 @@ function LogDensityFunction( end # If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`. -getcontext(f::LogDensityFunction) = f.context === nothing ? leafcontext(f.model.context) : f.context +function getcontext(f::LogDensityFunction) + return f.context === nothing ? leafcontext(f.model.context) : f.context +end # HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time # we need to define these annoying methods to ensure that we stay compatible with everything. From dfd8a8f328a6d441d732e66325396784b2a82f3f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 18 Jun 2024 08:33:20 +0100 Subject: [PATCH 5/6] replaces more references to `f.context` with `getcontext(f)` --- src/logdensityfunction.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 55a4a596a..8935edc12 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -49,7 +49,7 @@ struct LogDensityFunction{V,M,C} varinfo::V "model used for evaluation" model::M - "context used for evaluation; if `nothing`, `model.context` will be used when applicable" + "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" context::C end @@ -78,8 +78,8 @@ end # HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time # we need to define these annoying methods to ensure that we stay compatible with everything. -getsampler(f::LogDensityFunction) = getsampler(f.context) -hassampler(f::LogDensityFunction) = hassampler(f.context) +getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) +hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) _get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx) _get_indexer(ctx::SamplingContext) = ctx.sampler @@ -91,7 +91,7 @@ _get_indexer(::IsLeaf, ctx::AbstractContext) = Colon() Return the parameters of the wrapped varinfo as a vector. """ -getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(f.context)] +getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))] # LogDensityProblems interface function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) From 8342f62785733419721bfaf331faeba729b99f40 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 25 Jun 2024 17:30:57 +0100 Subject: [PATCH 6/6] Bump version to v0.28 Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6b131ef87..78acb2566 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.27.3" +version = "0.28" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"