Commit 55d5ef7

Fix ESS (#229)

devmotion authored Jul 31, 2020
1 parent 53c87c2 commit 55d5ef7
Showing 3 changed files with 110 additions and 101 deletions.
2 changes: 1 addition & 1 deletion Project.toml

@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 keywords = ["markov chain monte carlo", "probablistic programming"]
 license = "MIT"
 desc = "Chain types and utility functions for MCMC simulations."
-version = "4.0.1"
+version = "4.0.2"

 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
168 changes: 87 additions & 81 deletions src/ess.jl
@@ -8,9 +8,11 @@ The `ESSMethod` uses a standard algorithm for estimating the
 effective sample size of MCMC chains.
 It is based on the discussion by
-[Vehtari et al. (2019)](https://arxiv.org/pdf/1903.08008.pdf) and uses the
-plug-in estimator of the autocorrelation function discussed by
+[Vehtari et al. (2019)](https://arxiv.org/pdf/1903.08008.pdf) and uses a
+biased estimator of the autocovariance, as discussed by
 [Geyer (1992)](https://projecteuclid.org/euclid.ss/1177011137).
+In contrast to Geyer, the divisor `n - 1` is used in the estimation of
+the autocovariance to obtain the unbiased estimator of the variance for lag 0.
 """
 struct ESSMethod <: AbstractESSMethod end
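As a point of reference for the docstring change above: with the `n - 1` divisor, the lag-0 autocovariance is exactly the unbiased sample variance. A standalone sketch (illustration only; `autocov_hat` is a hypothetical helper, not package code):

using Statistics

# divisor `n - 1` even though the sum has `n - k` terms
autocov_hat(x, k) = sum(x[i] * x[i + k] for i in 1:(length(x) - k)) / (length(x) - 1)

x = randn(1_000)
x .-= mean(x)               # demeaned, as `ess_rhat` does before the lag sums
autocov_hat(x, 0) ≈ var(x)  # true: lag 0 recovers the unbiased variance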

@@ -22,10 +24,13 @@ the effective sample size of MCMC chains.
 It is based on the discussion by
-[Vehtari et al. (2019)](https://arxiv.org/pdf/1903.08008.pdf) and uses the
-plug-in estimator of the autocorrelation function discussed by
-[Geyer (1992)](https://projecteuclid.org/euclid.ss/1177011137). In contrast to
-`ESSMethod`, it uses fast Fourier transforms (FFTs) for
-estimating the autocorrelation function.
+[Vehtari et al. (2019)](https://arxiv.org/pdf/1903.08008.pdf) and uses a
+biased estimator of the autocovariance, as discussed by
+[Geyer (1992)](https://projecteuclid.org/euclid.ss/1177011137).
+In contrast to Geyer, the divisor `n - 1` is used in the estimation of
+the autocovariance to obtain the unbiased estimator of the variance for lag 0.
+In contrast to [`ESSMethod`](@ref), this method uses fast Fourier transforms
+(FFTs) for estimating the autocorrelation.
 """
 struct FFTESSMethod <: AbstractESSMethod end
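Both methods implement the same estimator and should agree numerically. A minimal usage sketch mirroring the calls in the test file further down (array layout `iterations × parameters × chains`):

using MCMCChains

x = randn(1_000, 2, 4)
ess_std, rhat_std = MCMCChains.ess_rhat(x; method = ESSMethod())
ess_fft, rhat_fft = MCMCChains.ess_rhat(x; method = FFTESSMethod())
ess_std ≈ ess_fft  # FFT-based results agree up to floating-point error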

@@ -37,27 +42,29 @@ MCMC chains.
 It is based on the discussion by
 [Vehtari et al. (2019)](https://arxiv.org/pdf/1903.08008.pdf) and uses the
-estimator of the autocorrelation function discussed in
+variogram estimator of the autocorrelation function discussed in
 [Bayesian Data Analysis (2013)](https://www.taylorfrancis.com/books/9780429113079).
 """
 struct BDAESSMethod <: AbstractESSMethod end

 # caches
-mutable struct ESSCache{T}
+struct ESSCache{T,S}
     samples::Matrix{T}
-    var::Vector{T}
+    chain_var::Vector{S}
 end

-mutable struct FFTESSCache{T,P,I}
-    A::T
+struct FFTESSCache{T,S,C,P,I}
+    samples::Matrix{T}
+    chain_var::Vector{S}
+    samples_cache::C
     plan::P
     invplan::I
-    niter::Int
 end

-mutable struct BDAESSCache{T}
+mutable struct BDAESSCache{T,S,M}
     samples::Matrix{T}
-    var::Vector{T}
+    chain_var::Vector{S}
+    mean_chain_var::M
 end

 function build_cache(::ESSMethod, samples::Matrix, var::Vector)
@@ -76,114 +83,116 @@ function build_cache(::FFTESSMethod, samples::Matrix, var::Vector)
     # create cache for FFT
     T = complex(eltype(samples))
     n = nextprod([2, 3], 2 * niter - 1)
-    A = Matrix{T}(undef, n, nchains)
+    samples_cache = Matrix{T}(undef, n, nchains)

     # create plans of FFTs
-    fft_plan = plan_fft!(A, 1)
-    ifft_plan = plan_ifft!(A, 1)
+    fft_plan = plan_fft!(samples_cache, 1)
+    ifft_plan = plan_ifft!(samples_cache, 1)

-    return FFTESSCache(A, fft_plan, ifft_plan, niter)
+    return FFTESSCache(samples, var, samples_cache, fft_plan, ifft_plan)
 end

 function build_cache(::BDAESSMethod, samples::Matrix, var::Vector)
     # check arguments
     nchains = size(samples, 2)
     length(var) == nchains || throw(DimensionMismatch())

-    return BDAESSCache(samples, var)
+    return BDAESSCache(samples, var, mean(var))
 end

-update_cache!(cache, samples::Matrix, var::Vector) = nothing
+update!(cache::ESSCache) = nothing

-function update_cache!(cache::FFTESSCache, samples::Matrix, var::Vector)
-    # check arguments
-    niter, nchains = size(samples)
-    niter == cache.niter || throw(DimensionMismatch())
-    A = cache.A
-    nchains == size(A, 2) || throw(DimensionMismatch())
-
+function update!(cache::FFTESSCache)
     # copy samples and add zero padding
-    n = size(A, 1)
-    T = eltype(A)
+    samples = cache.samples
+    samples_cache = cache.samples_cache
+    niter, nchains = size(samples)
+    n = size(samples_cache, 1)
+    T = eltype(samples_cache)
     @inbounds for j in 1:nchains
         for i in 1:niter
-            A[i, j] = samples[i, j]
+            samples_cache[i, j] = samples[i, j]
         end
         for i in (niter + 1):n
-            A[i, j] = zero(T)
+            samples_cache[i, j] = zero(T)
         end
     end

-    # compute unnormalized autocorrelation
-    cache.plan * A
-    @. A = abs2(A)
-    cache.invplan * A
+    # compute unnormalized autocovariance
+    cache.plan * samples_cache
+    @. samples_cache = abs2(samples_cache)
+    cache.invplan * samples_cache

     nothing
 end
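The three transform lines above compute every unnormalized autocovariance at once: padding to at least `2 * niter - 1` entries prevents circular wrap-around, and the inverse transform of `abs2.(fft(x))` yields the lag sums. A standalone sketch of that identity (`autocov_fft` is a hypothetical helper, assuming FFTW is available):

using FFTW

function autocov_fft(x::Vector{Float64})
    niter = length(x)
    n = nextprod([2, 3], 2 * niter - 1)  # padded length, as in `build_cache`
    y = zeros(ComplexF64, n)
    copyto!(y, x)                        # entries beyond `niter` stay zero
    fft!(y)
    @. y = abs2(y)
    ifft!(y)
    return real.(y[1:niter])             # entry `k + 1` is ∑ᵢ x[i] * x[i + k]
end

x = randn(100)
k = 3
autocov_fft(x)[k + 1] ≈ sum(x[i] * x[i + k] for i in 1:(length(x) - k))  # true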

-# estimation of the autocorrelation function
-function mean_autocorr(k::Int, cache::ESSCache)
+function update!(cache::BDAESSCache)
+    # recompute mean of within-chain variances
+    cache.mean_chain_var = mean(cache.chain_var)
+
+    return
+end
+
+function mean_autocov(k::Int, cache::ESSCache)
     # check arguments
     samples = cache.samples
     niter, nchains = size(samples)
     0 ≤ k < niter || throw(ArgumentError("only lags ≥ 0 and < $niter are supported"))

-    # compute mean autocorrelation
-    var = cache.var
-    s = zero(eltype(var))
+    # compute mean of unnormalized autocovariance estimates
     firstrange = 1:(niter - k)
     lastrange = (k + 1):niter
-    @inbounds for i in 1:nchains
-        # increment unnormalized correlation estimates
+    s = mean(1:nchains) do i
         if eltype(samples) isa BlasReal
             # call into BLAS if possible
-            s += dot(samples, firstrange, samples, lastrange) / var[i]
+            x = dot(samples, firstrange, samples, lastrange)
+            firstrange = firstrange .+ niter
+            lastrange = lastrange .+ niter
+            return x
         else
             # otherwise use views
-            s += dot(view(samples, firstrange, i), view(samples, lastrange, i)) / var[i]
+            return dot(view(samples, firstrange, i), view(samples, lastrange, i))
         end
     end

-    return s / length(samples)
+    # normalize autocovariance estimators by `niter - 1` instead
+    # of `niter - k` to obtain
+    # - unbiased estimators of the variance for lag 0
+    # - biased but more stable estimators for all other lags as discussed by
+    #   Geyer (1992)
+    return s / (niter - 1)
 end
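The BLAS branch relies on range-based `dot` methods that address the data through linear indices, which is why shifting both ranges by `niter` advances to the next chain. A tiny standalone sketch of that indexing trick (on an explicit vector, where those methods are defined):

using LinearAlgebra

A = reshape(collect(1.0:6.0), 3, 2)  # 3 iterations, 2 chains
a = vec(A)
niter, k = 3, 1
firstrange = 1:(niter - k)
lastrange = (k + 1):niter
dot(a, firstrange, a, lastrange)                    # chain 1, lag 1: 1*2 + 2*3 = 8
dot(a, firstrange .+ niter, a, lastrange .+ niter)  # chain 2: 4*5 + 5*6 = 50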

-function mean_autocorr(k::Int, cache::FFTESSCache)
+function mean_autocov(k::Int, cache::FFTESSCache)
     # check arguments
-    niter = cache.niter
+    niter, nchains = size(cache.samples)
     0 ≤ k < niter || throw(ArgumentError("only lags ≥ 0 and < $niter are supported"))

-    # compute mean autocorrelation
-    A = cache.A
-    nchains = size(A, 2)
-    s = zero(real(eltype(A)))
-    @inbounds for i in 1:nchains
-        s += real(A[k + 1, i]) / real(A[1, i])
-    end
-
-    return s / nchains
+    # compute mean autocovariance
+    # we use biased but more stable estimators as discussed by Geyer (1992)
+    samples_cache = cache.samples_cache
+    chain_var = cache.chain_var
+    return mean(1:nchains) do i
+        real(samples_cache[k + 1, i]) / real(samples_cache[1, i]) * chain_var[i]
+    end
 end
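Dividing by the lag-0 entry turns the raw FFT output into an autocorrelation, and multiplying by the corrected within-chain variance restores the same `niter - 1` normalization as `ESSMethod`. A standalone check of that algebra (illustration only):

x = randn(500)
x .-= sum(x) / length(x)                                # demeaned chain
r(k) = sum(x[i] * x[i + k] for i in 1:(length(x) - k))  # unnormalized autocovariance
chain_var = r(0) / (length(x) - 1)                      # corrected variance
r(3) / r(0) * chain_var ≈ r(3) / (length(x) - 1)        # true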

-function mean_autocorr(k::Int, cache::BDAESSCache)
+function mean_autocov(k::Int, cache::BDAESSCache)
     # check arguments
     samples = cache.samples
     niter, nchains = size(samples)
     0 ≤ k < niter || throw(ArgumentError("only lags ≥ 0 and < $niter are supported"))

-    # compute mean autocorrelation
-    var = cache.var
-    s = zero(eltype(var))
-    @inbounds for j in 1:nchains
-        sj = zero(s)
-        for i in 1:(niter - k)
-            sj += abs2(samples[i, j] - samples[k + i, j])
-        end
-        s += sj / var[j]
-    end
-
-    return 1 - s / (2 * length(samples))
+    # compute mean autocovariance
+    n = niter - k
+    idxs = 1:n
+    s = mean(1:nchains) do j
+        return sum(idxs) do i
+            abs2(samples[i, j] - samples[k + i, j])
+        end
+    end
+
+    return cache.mean_chain_var - s / (2 * n)
 end
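The return value uses the variogram identity: for a stationary chain, E[(x[i] - x[i + k])²] = 2(γ₀ - γₖ), so the within-chain variance minus half the variogram estimates the lag-`k` autocovariance. A rough standalone illustration:

using Statistics

x = randn(10_000)
k = 5
vario = sum(abs2(x[i] - x[i + k]) for i in 1:(length(x) - k)) / (length(x) - k)
gamma_k = var(x) - vario / 2  # ≈ lag-5 autocovariance, which is ≈ 0 for white noise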

"""
@@ -231,7 +240,7 @@ function ess_rhat(
     samples = Array{T}(undef, niter, nchains)

     # compute correction factor
-    correctionfactor = niter / (niter - 1)
+    correctionfactor = (niter - 1) / niter

     # define cache for the computation of the autocorrelation
     esscache = build_cache(method, samples, chain_var)
@@ -257,13 +266,12 @@ function ess_rhat(

     # calculate within-chain variance
     @inbounds for j in 1:nchains
-        chain_var[j] = var(view(samples, :, j); mean = chain_mean[j], corrected = false)
+        chain_var[j] = var(view(samples, :, j); mean = chain_mean[j], corrected = true)
     end
-    mean_chain_var = mean(chain_var)
-    W = correctionfactor * mean_chain_var
+    W = mean(chain_var)

     # compute variance estimator var₊, which accounts for between-chain variance as well
-    var₊ = mean_chain_var + var(chain_mean; corrected = true)
+    var₊ = correctionfactor * W + var(chain_mean; corrected = true)
     inv_var₊ = inv(var₊)
# estimate the potential scale reduction
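After this change the code assembles Vehtari et al.'s estimator var̂⁺ = (n - 1)/n · W + B/n, where W is the mean corrected within-chain variance and B/n equals the corrected variance of the chain means. A minimal standalone sketch of the same quantities (assumed `niter × nchains` layout):

using Statistics

samples = randn(1_000, 4)                         # niter × nchains
niter = size(samples, 1)
chain_mean = vec(mean(samples; dims = 1))
W = mean(var(samples; dims = 1))                  # corrected within-chain variances
var₊ = (niter - 1) / niter * W + var(chain_mean)  # var̂⁺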
@@ -273,25 +281,23 @@ function ess_rhat(
     samples .-= chain_mean

     # update cache
-    update_cache!(esscache, samples, chain_var)
+    update!(esscache)

-    # compute the first two autocorrelation terms
-    mean_ρ = mean_autocorr(1, esscache)
-    ρ_odd = 1 - inv_var₊ * (W - mean_ρ)
-    ρ_even = one(ρ_odd)
+    # compute the first two autocorrelation estimates
+    # by combining autocorrelation (or rather autocovariance) estimates of each chain
+    ρ_odd = 1 - inv_var₊ * (W - mean_autocov(1, esscache))
+    ρ_even = one(ρ_odd) # estimate at lag 0 is known

     # sum correlation estimates
     pₜ = ρ_even + ρ_odd
     sum_pₜ = pₜ

     k = 2
     while k < maxlag
-        # compute and combine autocorrelation of all chains
-        mean_ρ = mean_autocorr(k, esscache)
-        ρ_even = 1 - inv_var₊ * (W - mean_ρ)
-
-        mean_ρ = mean_autocorr(k + 1, esscache)
-        ρ_odd = 1 - inv_var₊ * (W - mean_ρ)
+        # compute subsequent autocorrelation of all chains
+        # by combining estimates of each chain
+        ρ_even = 1 - inv_var₊ * (W - mean_autocov(k, esscache))
+        ρ_odd = 1 - inv_var₊ * (W - mean_autocov(k + 1, esscache))

         # stop summation if p becomes non-positive
         Δ = ρ_even + ρ_odd
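The even/odd pairing and the stopping rule implement Geyer's initial positive sequence: sums of autocorrelations at consecutive lags have positive expectation, so summation is truncated once a pair turns non-positive and is dominated by noise. A simplified standalone sketch (`ρ` is a hypothetical autocorrelation function, not the package's combined estimate):

function truncated_sum(ρ, maxlag)
    ρ_odd = ρ(1)
    sum_pₜ = one(ρ_odd) + ρ_odd  # the lag-0 autocorrelation is exactly 1
    k = 2
    while k < maxlag
        Δ = ρ(k) + ρ(k + 1)
        Δ > 0 || break           # stop when a pair becomes non-positive
        sum_pₜ += Δ
        k += 2
    end
    return sum_pₜ
end

truncated_sum(k -> 0.5^k, 100)   # ≈ 2 = ∑ₖ₌₀ 0.5ᵏ for geometric decay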
41 changes: 22 additions & 19 deletions test/ess_tests.jl
@@ -31,27 +31,30 @@ end
 @testset "ESS and R̂ (IID samples)" begin
     Random.seed!(20)

-    x = randn(10_000, 40, 10)
-
-    ess_standard, rhat_standard = MCMCChains.ess_rhat(x)
-    ess_standard2, rhat_standard2 = MCMCChains.ess_rhat(x; method = ESSMethod())
-    ess_fft, rhat_fft = MCMCChains.ess_rhat(x; method = FFTESSMethod())
-    ess_bda, rhat_bda = MCMCChains.ess_rhat(x; method = BDAESSMethod())
-
-    # check that we get (roughly) the same results
-    @test ess_standard == ess_standard2
-    @test ess_standard ≈ ess_fft
-    @test rhat_standard == rhat_standard2 == rhat_fft == rhat_bda
-
-    # check that the estimates are reasonable
-    @test all(x -> isapprox(x, 100_000; atol = 2_500), ess_standard)
-    @test all(x -> isapprox(x, 1; atol = 0.1), rhat_standard)
-
-    @test count(x -> !isapprox(x, 100_000; atol = 2_500), ess_bda) == 7
-    @test all(x -> isapprox(x, 1; atol = 0.1), rhat_bda)
-
-    # BDA method fluctuates more
-    @test var(ess_standard) < var(ess_bda)
+    rawx = randn(10_000, 40, 10)
+
+    # Repeat tests with different scales
+    for scale in (1, 50, 100)
+        x = scale * rawx
+
+        ess_standard, rhat_standard = MCMCChains.ess_rhat(x)
+        ess_standard2, rhat_standard2 = MCMCChains.ess_rhat(x; method = ESSMethod())
+        ess_fft, rhat_fft = MCMCChains.ess_rhat(x; method = FFTESSMethod())
+        ess_bda, rhat_bda = MCMCChains.ess_rhat(x; method = BDAESSMethod())
+
+        # check that we get (roughly) the same results
+        @test ess_standard == ess_standard2
+        @test ess_standard ≈ ess_fft
+        @test rhat_standard == rhat_standard2 == rhat_fft == rhat_bda
+
+        # check that the estimates are reasonable
+        @test all(x -> isapprox(x, 100_000; atol = 2_500), ess_standard)
+        @test all(x -> isapprox(x, 100_000; atol = 2_500), ess_bda)
+        @test all(x -> isapprox(x, 1; atol = 0.1), rhat_standard)
+
+        # BDA method fluctuates more
+        @test var(ess_standard) < var(ess_bda)
+    end
 end
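The new loop over scales exercises scale invariance: ESS and R̂ should not depend on the overall scale of the samples. A quick standalone check in the same spirit:

using MCMCChains

x = randn(1_000, 2, 4)
ess1, rhat1 = MCMCChains.ess_rhat(x)
ess2, rhat2 = MCMCChains.ess_rhat(100 * x)
ess1 ≈ ess2 && rhat1 ≈ rhat2  # estimates are invariant under rescaling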

@testset "ESS and R̂ (identical samples)" begin

2 comments on commit 55d5ef7

@devmotion (Member, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/18767

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.0.2 -m "<description of version>" 55d5ef70776d355c66c39d7abcb2889662f2987b
git push origin v4.0.2
