From 001dd6ff5411c2be48de78f55b030384eabcd276 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Jun 2024 11:00:40 +0100 Subject: [PATCH 1/4] Fix docstring typo --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 63f624b4e..dd0690e69 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -188,7 +188,7 @@ getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(conte """ struct DefaultContext <: AbstractContext end -The `DefaultContext` is used by default to compute log the joint probability of the data +The `DefaultContext` is used by default to compute the log joint probability of the data and parameters when running the model. """ struct DefaultContext <: AbstractContext end From 1c81bd67906b64a423e3fb68fa29c5cfdb390cd0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Jun 2024 11:01:22 +0100 Subject: [PATCH 2/4] Add mention of context in the docstring of Model --- src/model.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 09c0c1be1..98bf527a4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,15 +1,17 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstactContext} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} + context::Ctx end A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` -types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, and missing -arguments `missings`. +types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing +arguments `missings`, and evaluation context of type `Ctx`. Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`. +`context` is by default `DefaultContext()`. An argument with a type of `Missing` will be in `missings` by default. However, in non-traditional use-cases `missings` can be defined differently. All variables in `missings` From 1a4eadb69764b0eb234ab4f04b10cfc412c4b75a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Jun 2024 11:28:50 +0100 Subject: [PATCH 3/4] Add a docstring for DynamicTransformationContext --- src/transforming.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/transforming.jl b/src/transforming.jl index 41c877c91..1f6c55e24 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -1,3 +1,14 @@ +""" + struct DynamicTransformationContext{isinverse} <: AbstractContext + +When a model is evaluated with this context, transform the accompanying `AbstractVarInfo` to +constrained space if `isinverse` or unconstrained if `!isinverse`. + +Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the +`DynamicTransformationContext` methods with more efficient implementations. +`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know +how to do the transformation, used by e.g. `SimpleVarInfo`. +""" struct DynamicTransformationContext{isinverse} <: AbstractContext end NodeTrait(::DynamicTransformationContext) = IsLeaf() From 5c2a625cc5574dcf2c79ded4c24fb8ee1ddbf63b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 2 Sep 2024 16:38:02 +0100 Subject: [PATCH 4/4] Tiny style improvements --- src/contexts.jl | 16 ++++++++-------- src/model.jl | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index dd0690e69..53b454df6 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -188,7 +188,7 @@ getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(conte """ struct DefaultContext <: AbstractContext end -The `DefaultContext` is used by default to compute the log joint probability of the data +The `DefaultContext` is used by default to compute the log joint probability of the data and parameters when running the model. """ struct DefaultContext <: AbstractContext end @@ -199,7 +199,7 @@ NodeTrait(context::DefaultContext) = IsLeaf() vars::Tvars end -The `PriorContext` enables the computation of the log prior of the parameters `vars` when +The `PriorContext` enables the computation of the log prior of the parameters `vars` when running the model. """ struct PriorContext{Tvars} <: AbstractContext @@ -213,8 +213,8 @@ NodeTrait(context::PriorContext) = IsLeaf() vars::Tvars end -The `LikelihoodContext` enables the computation of the log likelihood of the parameters when -running the model. `vars` can be used to evaluate the log likelihood for specific values +The `LikelihoodContext` enables the computation of the log likelihood of the parameters when +running the model. `vars` can be used to evaluate the log likelihood for specific values of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. """ struct LikelihoodContext{Tvars} <: AbstractContext @@ -229,10 +229,10 @@ NodeTrait(context::LikelihoodContext) = IsLeaf() loglike_scalar::T end -The `MiniBatchContext` enables the computation of -`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the -`loglike_scalar` field, typically equal to `the number of data points / batch size`. -This is useful in batch-based stochastic gradient descent algorithms to be optimizing +The `MiniBatchContext` enables the computation of +`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the +`loglike_scalar` field, typically equal to `the number of data points / batch size`. +This is useful in batch-based stochastic gradient descent algorithms to be optimizing `log(prior) + log(likelihood of all the data points)` in the expectation. """ struct MiniBatchContext{Tctx,T} <: AbstractContext diff --git a/src/model.jl b/src/model.jl index 98bf527a4..082ec3871 100644 --- a/src/model.jl +++ b/src/model.jl @@ -3,7 +3,7 @@ f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} - context::Ctx + context::Ctx=DefaultContext() end A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` @@ -1079,7 +1079,7 @@ end Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. # Examples - + ```jldoctest julia> using MCMCChains, Distributions @@ -1095,7 +1095,7 @@ julia> # construct a chain of samples using MCMCChains chain = Chains(rand(10, 2, 3), [:s, :m]); julia> logjoint(demo_model([1., 2.]), chain); -``` +``` """ function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model @@ -1126,7 +1126,7 @@ end Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. # Examples - + ```jldoctest julia> using MCMCChains, Distributions @@ -1142,7 +1142,7 @@ julia> # construct a chain of samples using MCMCChains chain = Chains(rand(10, 2, 3), [:s, :m]); julia> logprior(demo_model([1., 2.]), chain); -``` +``` """ function logprior(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model @@ -1173,7 +1173,7 @@ end Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. # Examples - + ```jldoctest julia> using MCMCChains, Distributions @@ -1189,7 +1189,7 @@ julia> # construct a chain of samples using MCMCChains chain = Chains(rand(10, 2, 3), [:s, :m]); julia> loglikelihood(demo_model([1., 2.]), chain); -``` +``` """ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model