From f45b252a848e47930f4fb32bf202438af6621c0f Mon Sep 17 00:00:00 2001
From: Carlos Parada
Date: Mon, 19 Dec 2022 05:02:28 +0000
Subject: [PATCH] Probability interface tutorial (#404)

First addition to the DynamicPPL tutorials; breaking this up as Hong suggested.
Goes over how to use the basic interfaces (e.g. logjoint, loglikelihood, logdensityof).

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: David Widmann
---
 docs/Project.toml                            |  2 +
 docs/make.jl                                 |  6 +-
 docs/src/assets/{dynamicppl.svg => logo.svg} |  0
 docs/src/tutorials/prob-interface.md         | 98 ++++++++++++++++++++
 4 files changed, 105 insertions(+), 1 deletion(-)
 rename docs/src/assets/{dynamicppl.svg => logo.svg} (100%)
 create mode 100644 docs/src/tutorials/prob-interface.md

diff --git a/docs/Project.toml b/docs/Project.toml
index 6e286b0ce..225e2f49e 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,11 +1,13 @@
 [deps]
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

 [compat]
 Distributions = "0.25"
 Documenter = "0.27"
+FillArrays = "0.13"
 Setfield = "0.7.1, 0.8, 1"
 StableRNGs = "1"

diff --git a/docs/make.jl b/docs/make.jl
index 6b88c18cd..cdf236655 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -11,7 +11,11 @@ makedocs(;
     sitename="DynamicPPL",
     format=Documenter.HTML(),
     modules=[DynamicPPL],
-    pages=["Home" => "index.md", "API" => "api.md"],
+    pages=[
+        "Home" => "index.md",
+        "API" => "api.md",
+        "Tutorials" => ["tutorials/prob-interface.md"],
+    ],
     strict=true,
     checkdocs=:exports,
 )

diff --git a/docs/src/assets/dynamicppl.svg b/docs/src/assets/logo.svg
similarity index 100%
rename from docs/src/assets/dynamicppl.svg
rename to docs/src/assets/logo.svg

diff --git a/docs/src/tutorials/prob-interface.md b/docs/src/tutorials/prob-interface.md
new file mode 100644
index 000000000..0a8a4dcc0
--- /dev/null
+++ b/docs/src/tutorials/prob-interface.md
@@ -0,0 +1,98 @@
+# The Probability Interface
+
+The easiest way to manipulate and query DynamicPPL models is via the DynamicPPL probability
+interface.
+
+Let's use a simple model of normally-distributed data as an example.
+
+```@example probinterface
+using DynamicPPL
+using Distributions
+using FillArrays
+using LinearAlgebra
+using Random

+Random.seed!(1776) # Set seed for reproducibility
+
+@model function gdemo(n)
+    μ ~ Normal(0, 1)
+    x ~ MvNormal(Fill(μ, n), I)
+    return nothing
+end
+nothing # hide
+```
+
+We generate some data using `μ = 0` and `σ = 1`:
+
+```@example probinterface
+dataset = randn(100)
+nothing # hide
+```
+
+## Conditioning and Deconditioning
+
+Bayesian models can be transformed with two main operations, conditioning and deconditioning.
+Conditioning takes a variable and fixes its value as known; deconditioning reverses this,
+treating the variable as random again.
+We condition a model by passing it and a `NamedTuple` of conditioned variables to `|`:
+
+```@example probinterface
+model = gdemo(length(dataset)) | (x=dataset, μ=0)
+nothing # hide
+```
+
+This operation can be reversed by applying `decondition`:
+
+```@example probinterface
+decondition(model)
+nothing # hide
+```
+
+We can also decondition only some of the variables:
+
+```@example probinterface
+decondition(model, :μ)
+nothing # hide
+```
+
+## Probabilities and Densities
+
+We often want to calculate the (unnormalized) probability density of an event.
+This might be a prior, a likelihood, or a posterior (joint) density.
+DynamicPPL provides convenient functions for this.
+
+For example, we can calculate the joint probability of a set of parameter values drawn
+from the prior:
+
+```@example probinterface
+model = gdemo(length(dataset)) | (x=dataset,)
+x1 = rand(model)
+logjoint(model, x1)
+```
+
+For convenience, we provide the functions `logprior`, `loglikelihood`, and `logjoint` to
+calculate probabilities for a named tuple, given a model.
+The log joint probability is always the sum of the log likelihood and the log prior:
+
+```@example probinterface
+@assert logjoint(model, x1) ≈ loglikelihood(model, x1) + logprior(model, x1)
+```
+
+## Example: Cross-validation
+
+To give an example of the probability interface in use, we can use it to estimate the
+performance of our model with cross-validation.
+In cross-validation, we split the dataset into several equal parts; each part then takes a
+turn as the validation set, while the model is trained on the remaining parts.
+Here, we measure fit using the cross entropy (Bayes loss).¹
+
+```@example probinterface
+function cross_val(model, dataset)
+    loss = zero(logjoint(model, rand(model)))
+
+    # Partition our dataset into 5 folds of 20 observations each:
+    test_folds = collect(Iterators.partition(dataset, 20))
+    train_folds = setdiff.((dataset,), test_folds)
+
+    for (train, test) in zip(train_folds, test_folds)
+        # First, we train the model on the training set.
+        # For this conjugate model (Normal(0, 1) prior, unit-variance likelihood),
+        # the posterior over μ is available in closed form:
+        n = length(train)
+        posterior = Normal(sum(train) / (n + 1), sqrt(1 / (n + 1)))
+
+        # Sample from the posterior:
+        samples = [(μ=m,) for m in rand(posterior, 1000)]
+
+        # Test: Monte Carlo estimate of the expected log joint on the held-out fold.
+        testing_model = gdemo(length(test)) | (x=test,)
+        loss += mean(samples) do sample
+            logjoint(testing_model, sample)
+        end
+    end
+    return loss
+end
+
+cross_val(model, dataset)
+```
+
+¹See [ParetoSmooth.jl](https://github.com/TuringLang/ParetoSmooth.jl) for a faster and more
+accurate implementation of cross-validation than the one provided here.
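+
+## Aside: Evaluating Densities with DensityInterface
+
+The commit title also mentions `logdensityof`: DynamicPPL models implement the generic
+[DensityInterface.jl](https://github.com/JuliaMath/DensityInterface.jl) API.
+The example below is a sketch only; it assumes `DensityInterface` is available in the docs
+environment (it is not declared in `docs/Project.toml` above) and that `logdensityof`
+accepts a model together with a `VarInfo` holding parameter values:
+
+```@example probinterface
+using DensityInterface: logdensityof
+
+vi = VarInfo(model) # runs the model once and stores the sampled values
+logdensityof(model, vi) # log joint density of the conditioned model at those values
+```
+
+For a conditioned model, this agrees with `logjoint` evaluated at the same parameter values.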