Skip to content

Commit

Permalink
Added @returned_quantities macro
Browse files Browse the repository at this point in the history
  • Loading branch information
torfjelde committed Oct 23, 2024
1 parent 54691bf commit 5c746c4
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ export AbstractVarInfo,
# Convenience macros
@addlogprob!,
@submodel,
@returned_quantities,
value_iterator_from_chain,
check_model,
check_model_and_trace,
Expand Down
182 changes: 182 additions & 0 deletions src/submodel_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,185 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__))
end
end
end

"""
@returned_quantities [prefix=...] model
Run `model` nested inside of another model and return the return-values of the `model`.
Valid expressions for `prefix=...` are:
- `prefix=false`: no prefix is used. This is the default.
- `prefix=expression`: results in the prefix `Symbol(expression)`.
Prefixing makes it possible to run the same model multiple times while keeping track of
all random variables correctly, i.e. without name clashes.
# Examples
## Simple example
```jldoctest submodel-returned-quantities; setup=:(using Distributions)
julia> @model function demo1(x)
x ~ Normal()
return 1 + abs(x)
end;
julia> @model function demo2(x, y)
a = @returned_quantities(demo1(x))
return y ~ Uniform(0, a)
end;
```
When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled:
```jldoctest submodel-return-quantities
julia> vi = VarInfo(demo2(missing, 0.4));
julia> @varname(x) in keys(vi)
true
```
Variable `a` is not tracked since it can be computed from the random variable `x` that was
tracked when running `demo1`:
```jldoctest submodel-returned-quantities
julia> @varname(a) in keys(vi)
false
```
We can check that the log joint probability of the model accumulated in `vi` is correct:
```jldoctest submodel-return-quantities
julia> x = vi[@varname(x)];
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
true
```
## With prefixing
```jldoctest submodel-return-quantities-prefix; setup=:(using Distributions)
julia> @model function demo1(x)
x ~ Normal()
return 1 + abs(x)
end;
julia> @model function demo2(x, y, z)
a = @returned_quantities prefix="sub1" demo1(x)
b = @returned_quantities prefix="sub2" demo1(y)
return z ~ Uniform(-a, b)
end;
```
When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and
`sub2.x` will be sampled:
```jldoctest submodel-return-quantities-prefix
julia> vi = VarInfo(demo2(missing, missing, 0.4));
julia> @varname(var"sub1.x") in keys(vi)
true
julia> @varname(var"sub2.x") in keys(vi)
true
```
Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and
`sub2.x` that were tracked when running `demo1`:
```jldoctest submodel-return-quantities-prefix
julia> @varname(a) in keys(vi)
false
julia> @varname(b) in keys(vi)
false
```
We can check that the log joint probability of the model accumulated in `vi` is correct:
```jldoctest submodel-return-quantities-prefix
julia> sub1_x = vi[@varname(var"sub1.x")];
julia> sub2_x = vi[@varname(var"sub2.x")];
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);
julia> getlogp(vi) ≈ logprior + loglikelihood
true
```
## Different ways of setting the prefix
```jldoctest submodel-return-quantities-prefix-alts; setup=:(using DynamicPPL, Distributions)
julia> @model inner() = x ~ Normal()
inner (generic function with 2 methods)
julia> # When `prefix` is unspecified, no prefix is used.
@model submodel_noprefix() = a = @returned_quantities inner()
submodel_noprefix (generic function with 2 methods)
julia> @varname(x) in keys(VarInfo(submodel_noprefix()))
true
julia> # Explicitely don't use any prefix.
@model submodel_prefix_false() = a = @returned_quantities prefix=false inner()
submodel_prefix_false (generic function with 2 methods)
julia> @varname(x) in keys(VarInfo(submodel_prefix_false()))
true
julia> # Using a static string.
@model submodel_prefix_string() = a = @returned_quantities prefix="my prefix" inner()
submodel_prefix_string (generic function with 2 methods)
julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string()))
true
julia> # Using string interpolation.
@model submodel_prefix_interpolation() = a = @returned_quantities prefix="\$(nameof(inner()))" inner()
submodel_prefix_interpolation (generic function with 2 methods)
julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation()))
true
julia> # Or using some arbitrary expression.
@model submodel_prefix_expr() = a = @returned_quantities prefix=1 + 2 inner()
submodel_prefix_expr (generic function with 2 methods)
julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr()))
true
```
"""
macro returned_quantities(expr)
return returned_quantities_expr(:(prefix = false), expr)
end

macro returned_quantities(prefix_expr, expr)
return returned_quantities_expr(prefix_expr, expr)
end

"""
@returned_quantities_expr model
Returns an expression that captures the return-values of a model in addition to the varinfo.
"""
function returned_quantities_expr(prefix_expr, expr, ctx=esc(:__context__))
prefix_left, prefix = getargs_assignment(prefix_expr)
if prefix_left !== :prefix
error("$(prefix_left) is not a valid kwarg")
end

# The user expects `@submodel ...` to return the
# return-value of the `...`, hence we need to capture
# the return-value and handle it correctly.
@gensym retval

# Prefix.
if prefix !== nothing
ctx = prefix_submodel_context(prefix, ctx)
end
return quote
# Evaluate the model and capture the return values + varinfo.
$retval, $(esc(:__varinfo__)) = $(_evaluate!!)(
$(esc(expr)), $(esc(:__varinfo__)), $(ctx)
)

# Return the return-value of the model.
$retval
end
end

0 comments on commit 5c746c4

Please sign in to comment.