diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a5d178125..dc27031fe 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -125,6 +125,7 @@ export AbstractVarInfo, # Convenience macros @addlogprob!, @submodel, + @returned_quantities, value_iterator_from_chain, check_model, check_model_and_trace, diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 050bf31fc..b6c92078b 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -248,3 +248,185 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__)) end end end + +""" + @returned_quantities [prefix=...] model + +Run `model` nested inside of another model and return the return-values of the `model`. + +Valid expressions for `prefix=...` are: +- `prefix=false`: no prefix is used. This is the default. +- `prefix=expression`: results in the prefix `Symbol(expression)`. + +Prefixing makes it possible to run the same model multiple times while keeping track of +all random variables correctly, i.e. without name clashes. + +# Examples + +## Simple example +```jldoctest submodel-returned-quantities; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2(x, y) + a = @returned_quantities(demo1(x)) + return y ~ Uniform(0, a) + end; +``` + +When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: +```jldoctest submodel-return-quantities +julia> vi = VarInfo(demo2(missing, 0.4)); + +julia> @varname(x) in keys(vi) +true +``` + +Variable `a` is not tracked since it can be computed from the random variable `x` that was +tracked when running `demo1`: +```jldoctest submodel-returned-quantities +julia> @varname(a) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-return-quantities +julia> x = vi[@varname(x)]; + +julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +true +``` + +## With prefixing +```jldoctest submodel-return-quantities-prefix; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2(x, y, z) + a = @returned_quantities prefix="sub1" demo1(x) + b = @returned_quantities prefix="sub2" demo1(y) + return z ~ Uniform(-a, b) + end; +``` + +When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and +`sub2.x` will be sampled: +```jldoctest submodel-return-quantities-prefix +julia> vi = VarInfo(demo2(missing, missing, 0.4)); + +julia> @varname(var"sub1.x") in keys(vi) +true + +julia> @varname(var"sub2.x") in keys(vi) +true +``` + +Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and +`sub2.x` that were tracked when running `demo1`: +```jldoctest submodel-return-quantities-prefix +julia> @varname(a) in keys(vi) +false + +julia> @varname(b) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-return-quantities-prefix +julia> sub1_x = vi[@varname(var"sub1.x")]; + +julia> sub2_x = vi[@varname(var"sub2.x")]; + +julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); + +julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); + +julia> getlogp(vi) ≈ logprior + loglikelihood +true +``` + +## Different ways of setting the prefix +```jldoctest submodel-return-quantities-prefix-alts; setup=:(using DynamicPPL, Distributions) +julia> @model inner() = x ~ Normal() +inner (generic function with 2 methods) + +julia> # When `prefix` is unspecified, no prefix is used. + @model submodel_noprefix() = a = @returned_quantities inner() +submodel_noprefix (generic function with 2 methods) + +julia> @varname(x) in keys(VarInfo(submodel_noprefix())) +true + +julia> # Explicitely don't use any prefix. + @model submodel_prefix_false() = a = @returned_quantities prefix=false inner() +submodel_prefix_false (generic function with 2 methods) + +julia> @varname(x) in keys(VarInfo(submodel_prefix_false())) +true + +julia> # Using a static string. + @model submodel_prefix_string() = a = @returned_quantities prefix="my prefix" inner() +submodel_prefix_string (generic function with 2 methods) + +julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) +true + +julia> # Using string interpolation. + @model submodel_prefix_interpolation() = a = @returned_quantities prefix="\$(nameof(inner()))" inner() +submodel_prefix_interpolation (generic function with 2 methods) + +julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) +true + +julia> # Or using some arbitrary expression. + @model submodel_prefix_expr() = a = @returned_quantities prefix=1 + 2 inner() +submodel_prefix_expr (generic function with 2 methods) + +julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) +true +``` +""" +macro returned_quantities(expr) + return returned_quantities_expr(:(prefix = false), expr) +end + +macro returned_quantities(prefix_expr, expr) + return returned_quantities_expr(prefix_expr, expr) +end + +""" + @returned_quantities_expr model + +Returns an expression that captures the return-values of a model in addition to the varinfo. +""" +function returned_quantities_expr(prefix_expr, expr, ctx=esc(:__context__)) + prefix_left, prefix = getargs_assignment(prefix_expr) + if prefix_left !== :prefix + error("$(prefix_left) is not a valid kwarg") + end + + # The user expects `@submodel ...` to return the + # return-value of the `...`, hence we need to capture + # the return-value and handle it correctly. + @gensym retval + + # Prefix. + if prefix !== nothing + ctx = prefix_submodel_context(prefix, ctx) + end + return quote + # Evaluate the model and capture the return values + varinfo. + $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( + $(esc(expr)), $(esc(:__varinfo__)), $(ctx) + ) + + # Return the return-value of the model. + $retval + end +end