Add Gibbsian polar slice sampler #4

Merged
merged 10 commits into from
May 26, 2024
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,12 +1,13 @@
name = "SliceSampling"
uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
version = "0.2.1"
version = "0.3.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -23,6 +24,7 @@ AbstractMCMC = "4, 5"
Accessors = "0.1"
Distributions = "0.25"
FillArrays = "1"
LinearAlgebra = "1"
LogDensityProblems = "2"
Random = "1"
Requires = "1"
38 changes: 36 additions & 2 deletions README.md
@@ -1,9 +1,43 @@
# Implementation of slice sampling algorithms

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://Red-Portal.github.io/SliceSampling.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://Red-Portal.github.io/SliceSampling.jl/dev/)
[![Build Status](https://github.com/Red-Portal/SliceSampling.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/Red-Portal/SliceSampling.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/Red-Portal/SliceSampling.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Red-Portal/SliceSampling.jl)

This package implements slice sampling algorithms accessible through the `AbstractMCMC` [interface](https://github.com/TuringLang/AbstractMCMC.jl).
For general usage, please refer to [here](https://turinglang.org/SliceSampling.jl/dev/general/).

## Implemented Algorithms
- Univariate slice sampling algorithms with coordinate-wise Gibbs sampling by R. Neal[^N2003].
- Latent slice sampling by Li and Walker[^LW2023].
- Gibbsian polar slice sampling by P. Schär, M. Habeck, and D. Rudolf[^SHR2023].

## Example with Turing Models
This package supports the [Turing](https://github.com/TuringLang/Turing.jl) probabilistic programming framework:

```@example turing
using Distributions
using Turing
using SliceSampling

@model function demo()
s ~ InverseGamma(3, 3)
m ~ Normal(0, sqrt(s))
end

sampler = LatentSlice(2)
n_samples = 10000
model = demo()
sample(model, externalsampler(sampler), n_samples; initial_params=[1.0, 0.0])
```

[^N2003]: Neal, R. M. (2003). Slice sampling. The annals of statistics, 31(3), 705-767.
[^LW2023]: Li, Y., & Walker, S. G. (2023). A latent slice sampling algorithm. Computational Statistics & Data Analysis, 179, 107652.
[^SHR2023]: Schär, P., Habeck, M., & Rudolf, D. (2023, July). Gibbsian polar slice sampling. In International Conference on Machine Learning.
=======
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://TuringLang.org/SliceSampling.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://TuringLang.org/SliceSampling.jl/dev/)
[![Build Status](https://github.com/TuringLang/SliceSampling.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/Red-Portal/SliceSampling.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/TuringLang/SliceSampling.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/Red-Portal/SliceSampling.jl)


For a working example, please see [here](https://turinglang.org/SliceSampling.jl/dev/general/).
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -6,6 +6,7 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SliceSampling = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
9 changes: 5 additions & 4 deletions docs/make.jl
@@ -15,10 +15,11 @@ makedocs(;
assets=String[],
),
pages=[
"Home" => "index.md",
"General Usage" => "general.md",
"Univariate Slice Sampling" => "univariate_slice.md",
"Latent Slice Sampling" => "latent_slice.md"
"Home" => "index.md",
"General Usage" => "general.md",
"Univariate Slice Sampling" => "univariate_slice.md",
"Latent Slice Sampling" => "latent_slice.md",
"Gibbsian Polar Slice Sampling" => "gibbs_polar.md"
],
)

76 changes: 76 additions & 0 deletions docs/src/gibbs_polar.md
@@ -0,0 +1,76 @@

# [Gibbsian Polar Slice Sampling](@id polar)

## Introduction
Gibbsian polar slice sampling (GPSS) is a recent vector-valued slice sampling algorithm proposed by P. Schär, M. Habeck, and D. Rudolf[^SHR2023].
It is a computationally efficient variant of the polar slice sampler previously proposed by Roberts and Rosenthal[^RR2002].
Unlike other slice sampling algorithms, it operates a Gibbs sampler over polar coordinates, reminiscent of the elliptical slice sampler (ESS).
Because it relies on polar coordinates, GPSS only works reliably in more than one dimension.
However, unlike ESS, GPSS is applicable to any target distribution.


## Description
For a $$d$$-dimensional target distribution, GPSS utilizes the following augmented target distribution:
```math
\begin{aligned}
p(x, T) &= \varrho_{\pi}^{(0)}(x) \varrho_{\pi}^{(1)}(x) \, \operatorname{Uniform}\left(T; 0, \varrho_{\pi}^{(1)}(x)\right) \\
\varrho_{\pi}^{(0)}(x) &= {\lVert x \rVert}^{1 - d} \\
\varrho_{\pi}^{(1)}(x) &= {\lVert x \rVert}^{d-1} \pi\left(x\right)
\end{aligned}
```
As described in Appendix A of the GPSS paper, sampling from $$\varrho^{(1)}(x)$$ in polar coordinates targets the augmented target distribution.
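
The role of the two factors can be checked directly from the definitions above. The following is a short sanity check rather than the argument given in the paper; it assumes $$x \neq 0$$ and allows $$\pi$$ to be unnormalized:
```math
\begin{aligned}
\varrho_{\pi}^{(0)}(x) \, \varrho_{\pi}^{(1)}(x) &= {\lVert x \rVert}^{1-d} \, {\lVert x \rVert}^{d-1} \, \pi\left(x\right) = \pi\left(x\right) \\
\operatorname{Uniform}\left(T; 0, \varrho_{\pi}^{(1)}(x)\right) &= \frac{\mathbb{1}\left\{ 0 < T < \varrho_{\pi}^{(1)}(x) \right\}}{\varrho_{\pi}^{(1)}(x)} \\
p(x, T) &= \varrho_{\pi}^{(0)}(x) \, \mathbb{1}\left\{ 0 < T < \varrho_{\pi}^{(1)}(x) \right\} \\
\int_{0}^{\infty} p(x, T) \, \mathrm{d}T &= \varrho_{\pi}^{(0)}(x) \, \varrho_{\pi}^{(1)}(x) = \pi\left(x\right)
\end{aligned}
```
In other words, marginalizing out $$T$$ recovers the original target, so the $$x$$-component of the augmented chain has the desired distribution.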

In a high-level view, GPSS operates a Gibbs sampler in the following fashion:
```math
\begin{aligned}
T_n &\sim \operatorname{Uniform}\left(0, \varrho^{(1)}\left(x_{n-1}\right)\right) \\
\theta_n &\sim \operatorname{Uniform}\left\{ \theta \in \mathbb{S}^{d-1} \mid \varrho^{(1)}\left(r_{n-1} \theta\right) > T_n \right\} \\
r_n &\sim \operatorname{Uniform}\left\{ r \in \mathbb{R}_{\geq 0} \mid \varrho^{(1)}\left(r \theta_n\right) > T_n \right\} \\
x_n &= \theta_n r_n,
\end{aligned}
```
where $$T_n$$ is the usual acceptance threshold auxiliary variable, while $$\theta$$ and $$r$$ are the sampler states in polar coordinates.
The Gibbs steps on $$\theta$$ and $$r$$ are implemented through specialized shrinkage procedures.
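
The threshold draw and the slice-membership test used by the two shrinkage steps can be sketched in Julia as follows. This is a minimal illustration under stated assumptions, not the package's implementation: the names `log_rho1`, `logπ`, and `on_slice` are hypothetical, the target is assumed to be an unnormalized standard Gaussian, and everything is computed in log-space.
```julia
using LinearAlgebra
using Random

# Log of ϱ⁽¹⁾(x) = ‖x‖^(d-1) π(x) for a log-density `logπ` (hypothetical helper).
log_rho1(logπ, x) = (length(x) - 1) * log(norm(x)) + logπ(x)

# Assumed target for illustration only: an unnormalized standard Gaussian.
logπ(x) = -sum(abs2, x) / 2

rng    = Random.MersenneTwister(1)
x_prev = [1.0, 2.0, 2.0]

# Polar decomposition of the previous state: x = r θ with ‖θ‖ = 1.
r_prev = norm(x_prev)
θ_prev = x_prev / r_prev

# T_n ~ Uniform(0, ϱ⁽¹⁾(x_{n-1})), drawn in log-space for numerical stability.
logT = log_rho1(logπ, x_prev) + log(rand(rng))

# A candidate (r, θ) lies on the slice when log ϱ⁽¹⁾(r θ) > log T_n.
on_slice(r, θ) = log_rho1(logπ, r * θ) > logT
```
The shrinkage procedures for $$\theta$$ and $$r$$ would then repeatedly propose candidates and keep the first one for which a test like `on_slice` holds; the previous state always passes the test, which is what guarantees that the shrinkage terminates.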

The only tunable parameter of the algorithm is the size of the search interval (window) of the shrinkage sampler for the radius variable $$r$$.

!!! info
Since the direction and radius variables are states of the Markov chain, this sampler is **not reversible** with respect to the samples of the log-target $$x$$.

## Interface

!!! warning
By the nature of polar coordinates, GPSS only works reliably for targets with dimension $$d \geq 2$$.

!!! warning
When initializing the chain (*e.g.*, via the `initial_params` keyword argument of `AbstractMCMC.sample`), it is necessary to initialize from a point $$x_0$$ that has a sensible norm $$\lVert x_0 \rVert > 0$$; otherwise, the chain will start from a pathological point in polar coordinates. If the norm is smaller than `1e-5`, the current implementation automatically sets the initial radius to `1e-5`.

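Both warnings can be checked on the user side before sampling. The helper below is a hypothetical sketch and not part of `SliceSampling`; its name, its error messages, and the choice to throw rather than clamp are assumptions, with the default threshold mirroring the `1e-5` value mentioned above.
```julia
using LinearAlgebra

# Hypothetical user-side check (not a SliceSampling.jl function): reject
# initial points that are one-dimensional or too close to the origin.
function check_initial_params(x0::AbstractVector{<:Real}; min_radius::Real = 1e-5)
    length(x0) >= 2 ||
        throw(ArgumentError("GPSS expects a target with at least 2 dimensions"))
    norm(x0) >= min_radius ||
        throw(ArgumentError("initial point has norm $(norm(x0)) < $(min_radius)"))
    return x0
end

# The norm of ones(10) is sqrt(10), so this point passes the check.
x0 = check_initial_params(ones(10))
```
The returned vector could then be passed as `initial_params` when calling `sample`.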

```@docs
GibbsPolarSlice
```

## Demonstration
As illustrated in the original paper, GPSS shows good performance on heavy-tailed targets despite being a multivariate slice sampler:
```@example gpss
using Distributions
using Turing
using SliceSampling
using LinearAlgebra
using Plots

@model function demo()
x ~ MvTDist(1, zeros(10), Matrix(I,10,10))
end
model = demo()

n_samples = 10000
chain = sample(model, externalsampler(GibbsPolarSlice(10)), n_samples; initial_params=ones(10))
histogram(chain[:,1,:], xlims=[-10,10])
savefig("cauchy_gpss.svg")
```
![](cauchy_gpss.svg)


[^SHR2023]: Schär, P., Habeck, M., & Rudolf, D. (2023, July). Gibbsian polar slice sampling. In International Conference on Machine Learning.
[^RR2002]: Roberts, G. O., & Rosenthal, J. S. (2002). The polar slice sampler. Stochastic Models, 18(2), 257-280.
6 changes: 3 additions & 3 deletions docs/src/latent_slice.md
@@ -15,9 +15,9 @@ where $$y$$ is the parameters of the log-target $$\pi$$, $$s$$ is the width of t
Naturally, the sampler operates as a blocked-Gibbs sampler
```math
\begin{aligned}
l &\sim \operatorname{Uniform}\left(l; \; y - s/2,\, y + s/2\right) \\
s &\sim p(s \mid y, l) \\
y &\sim \operatorname{slice-sampler}\left(y \mid s, l\right),
l_n &\sim \operatorname{Uniform}\left(l; \; y_{n-1} - s_{n-1}/2,\, y_{n-1} + s_{n-1}/2\right) \\
s_n &\sim p(s \mid y_{n-1}, l_{n}) \\
y_n &\sim \operatorname{shrinkage}\left(y \mid s_n, l_n\right),
\end{aligned}
```
where $$y$$ is updated using the shrinkage procedure by Neal[^N2003] using the initial interval formed by $$(s, l)$$.
13 changes: 12 additions & 1 deletion src/SliceSampling.jl
@@ -5,6 +5,7 @@ using AbstractMCMC
using Accessors
using Distributions
using FillArrays
using LinearAlgebra
using LogDensityProblems
using SimpleUnPack
using Random
@@ -48,6 +49,13 @@ Return the initial sample for the `model` using the random number generator `rng
"""
function initial_sample end

function exceeded_max_prop(max_prop::Int)
error("Exceeded maximum number of proposals ($(max_prop)).\n",
"Here are possible causes:\n",
"- The model might be broken or pathological.\n",
"- There might be a bug in the sampler.")
end

# Univariate Slice Sampling Algorithms
export Slice, SliceSteppingOut, SliceDoublingOut

@@ -64,9 +72,12 @@ include("doublingout.jl")

# Latent Slice Sampling
export LatentSlice

include("latent.jl")

# Gibbsian Polar Slice Sampling
export GibbsPolarSlice
include("gibbspolar.jl")

# Turing Compatibility

if !isdefined(Base, :get_extension)
21 changes: 15 additions & 6 deletions src/doublingout.jl
@@ -1,20 +1,29 @@

"""
SliceDoublingOut(max_doubling_out, window)
SliceDoublingOut(window)
SliceDoublingOut(window; max_doubling_out, max_proposals)

Univariate slice sampling by automatically adapting the initial interval through the "doubling-out" procedure (Scheme 4 by Neal[^N2003])

# Fields
- `max_doubling_out`: Maximum number of "doubling outs" (default: 8).
# Arguments
- `window::Union{<:Real, <:AbstractVector}`: Proposal window.

# Keyword Arguments
- `max_doubling_out`: Maximum number of "doubling outs" (default: 8).
- `max_proposals::Int`: Maximum number of proposals allowed until throwing an error (default: `typemax(Int)`).
"""
struct SliceDoublingOut{W <: Union{<:AbstractVector, <:Real}} <: AbstractGibbsSliceSampling
max_doubling_out::Int
window ::W
max_doubling_out::Int
max_proposals ::Int
end

SliceDoublingOut(window::Union{<:AbstractVector, <:Real}) = SliceDoublingOut(8, window)
function SliceDoublingOut(
window ::Union{<:AbstractVector, <:Real};
max_doubling_out::Int = 8,
max_proposals ::Int = typemax(Int),
)
SliceDoublingOut(window, max_doubling_out, max_proposals)
end

function find_interval(
rng ::Random.AbstractRNG,
11 changes: 6 additions & 5 deletions src/gibbs.jl
@@ -38,6 +38,7 @@ function AbstractMCMC.step(
state ::GibbsSliceState;
kwargs...,
)
max_prop = sampler.max_proposals
logdensitymodel = model.logdensity
w = if sampler.window isa Real
Fill(sampler.window, LogDensityProblems.dimension(logdensitymodel))
@@ -48,15 +49,15 @@ function AbstractMCMC.step(
θ = copy(state.transition.params)
@assert length(w) == length(θ) "window size does not match parameter size"

total_props = 0
n_props = zeros(Int, length(θ))
for idx in shuffle(rng, 1:length(θ))
model_gibbs = GibbsObjective(logdensitymodel, idx, θ)
θ′idx, ℓp, props = slice_sampling_univariate(
rng, sampler, model_gibbs, w[idx], ℓp, θ[idx]
rng, sampler, model_gibbs, w[idx], ℓp, θ[idx], max_prop,
)
total_props += props
θ[idx] = θ′idx
n_props[idx] = props
θ[idx] = θ′idx
end
t = Transition(θ, ℓp, (num_proposals=total_props,))
t = Transition(θ, ℓp, (num_proposals=n_props,))
t, GibbsSliceState(t)
end