Adding implementation of Rstar statistic (#238)
* Added implementation of Rstar statistic

* extracted method to work with arrays instead, added missing import

* Update src/rstar.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/rstar.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/rstar.jl

Co-authored-by: David Widmann <[email protected]>

* Added more tests, minor changes and fixed Algo. 2

* Update src/rstar.jl

Co-authored-by: David Widmann <[email protected]>

* Changed code to use MLJ interface.

* fixed bug

* Update src/rstar.jl

Co-authored-by: David Widmann <[email protected]>

* minor code improvements, added README entry

* added Ref

* Update src/MCMCChains.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/rstar.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/rstar.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/rstar.jl

Co-authored-by: David Widmann <[email protected]>

* final updates

* Update Project.toml

Co-authored-by: David Widmann <[email protected]>

* unified test

Co-authored-by: David Widmann <[email protected]>
trappmartin and devmotion authored Sep 17, 2020
1 parent c560033 commit 4da661d
Showing 7 changed files with 180 additions and 6 deletions.
9 changes: 7 additions & 2 deletions Project.toml
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
-version = "4.1.0"
+version = "4.2.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -14,6 +14,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -33,6 +34,7 @@ Compat = "2.2, 3"
Distributions = "0.21, 0.22, 0.23"
Formatting = "0.4"
IteratorInterfaceExtensions = "0.1.1, 1"
MLJModelInterface = "0.3.5"
NaturalSort = "1"
PrettyTables = "0.9"
RecipesBase = "0.7, 0.8, 1.0"
@@ -47,9 +49,12 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"

[targets]
-test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots"]
+test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots", "MLJ", "MLJModels", "XGBoost"]
26 changes: 26 additions & 0 deletions README.md
@@ -225,6 +225,32 @@ heideldiag(c::Chains; alpha=0.05, eps=0.1, etype=:imse)
rafterydiag(c::Chains; q=0.025, r=0.005, s=0.95, eps=0.001)
```

#### Rstar Diagnostic
The R* (Rstar) diagnostic is described in [https://arxiv.org/pdf/2003.07900.pdf](https://arxiv.org/pdf/2003.07900.pdf).
Note that using it requires MLJ and MLJModels to be installed.

Usage:

```julia
using MCMCChains, MLJ, MLJModels
using Statistics  # provides `mean`

chn = ... # placeholder: sampling results of multiple chains

# select classifier used to compute the diagnostic
classif = @load XGBoostClassifier

# estimate diagnostic (the classifier is the first argument)
Rs = rstar(classif, chn)
R = mean(Rs)

# visualize distribution
using Plots
histogram(Rs)
```

See `? rstar` for more details.


### Model Selection
#### Deviance Information Criterion (DIC)
```julia
5 changes: 4 additions & 1 deletion src/MCMCChains.jl
@@ -12,7 +12,7 @@ using SpecialFunctions
using Formatting
import StatsBase: autocov, counts, sem, AbstractWeights,
autocor, describe, quantile, sample, summarystats, cov

import MLJModelInterface
import NaturalSort
import PrettyTables
import Tables
@@ -36,6 +36,8 @@ export summarize
export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
export hpd, ess

export rstar

export ESSMethod, FFTESSMethod, BDAESSMethod

"""
@@ -73,5 +75,6 @@ include("stats.jl")
include("modelstats.jl")
include("plot.jl")
include("tables.jl")
include("rstar.jl")

end # module
96 changes: 96 additions & 0 deletions src/rstar.jl
@@ -0,0 +1,96 @@
"""
rstar([rng ,] classif::Supervised, chains::Chains; kwargs...)
rstar([rng ,] classif::Supervised, x::AbstractMatrix, y::AbstractVector; kwargs...)
Compute the R* convergence diagnostic of MCMC.
This implementation is an adaptation of Algorithms 1 and 2 described in [Lambert & Vehtari]. Note that the correctness of the statistic depends on the convergence of the classifier used internally. With the XGBoost backend, you can check whether training converged by inspecting the printed RMSE values. To adjust the number of iterations used to train the classifier, set `niter` accordingly.
# Keyword Arguments
* `subset = 0.8` ... Subset used to train the classifier, i.e. 0.8 implies 80% of the samples are used.
* `iterations = 10` ... Number of iterations used to estimate the statistic. If the classifier is not probabilistic, i.e. does not return class probabilities, it is advisable to use a value of one.
* `verbosity = 0` ... Verbosity level used during fitting of the classifier.
# Usage
```julia
using MLJ, MLJModels
# You need to load MLJBase and the respective package you are using for classification first.
# Select a classifier to compute the Rstar statistic.
# For example the XGBoost classifier.
classif = @load XGBoostClassifier()
# Compute 20 samples of the R* statistic by sampling according to the prediction probabilities.
Rs = rstar(classif, chn, iterations = 20)
# estimate Rstar
R = mean(Rs)
# visualize distribution
histogram(Rs)
```
## References:
[Lambert & Vehtari] Ben Lambert and Aki Vehtari. "R∗: A robust MCMC convergence diagnostic with uncertainty using gradient-boosted machines." arXiv, 2020.
"""
function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; iterations = 10, subset = 0.8, verbosity = 0)

size(x,1) != length(y) && throw(DimensionMismatch())
iterations >= 1 || throw(ArgumentError("Number of iterations has to be positive!"))

if iterations > 1 && classif isa MLJModelInterface.Deterministic
@warn("Classifier is not a probabilistic classifier but number of iterations is > 1.")
elseif iterations == 1 && classif isa MLJModelInterface.Probabilistic
@warn("Classifier is probabilistic but number of iterations is equal to one.")
end

N = length(y)
K = length(unique(y))

# randomly sub-select training and testing set
Ntrain = round(Int, N*subset)
Ntest = N - Ntrain

ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain+1):N)

# train the supplied classifier on the training subset
fitresult, _ = MLJModelInterface.fit(classif, verbosity, Tables.table(x[train_ids,:]), MLJModelInterface.categorical(y[train_ids]))

xtest = Tables.table(x[test_ids,:])
ytest = view(y, test_ids)

Rstats = map(i -> K*rstar_score(rng, classif, fitresult, xtest, ytest), 1:iterations)
return Rstats
end

function rstar(classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; kwargs...)
rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...)
end

function rstar(classif::MLJModelInterface.Supervised, chn::Chains; kwargs...)
return rstar(Random.GLOBAL_RNG, classif, chn; kwargs...)
end

function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, chn::Chains; kwargs...)
nchains = size(chn, 3)
nchains <= 1 && throw(DimensionMismatch())

# collect data
x = Array(chn)
y = repeat(chains(chn); inner = size(chn,1))

return rstar(rng, classif, x, y; kwargs...)
end

function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Probabilistic, fitresult, xtest, ytest)
pred = get.(rand.(Ref(rng), MLJModelInterface.predict(classif, fitresult, xtest)))
return mean(((p,y),) -> p == y, zip(pred, ytest))
end

function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Deterministic, fitresult, xtest, ytest)
pred = MLJModelInterface.predict(classif, fitresult, xtest)
return mean(((p,y),) -> p == y, zip(pred, ytest))
end
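
The deterministic branch of `rstar_score` reduces to "K times the held-out accuracy". A minimal sketch of that computation follows, assuming a hypothetical nearest-centroid classifier (`centroid_predict`) in place of an MLJ model so that it runs with the standard library only; `rstar_sketch` is likewise an illustrative name, not part of this commit.

```julia
using Random, Statistics

# Hypothetical stand-in for the MLJ classifier: predicts the label whose
# training-set centroid is closest to each test row.
function centroid_predict(xtrain, ytrain, xtest)
    labels = sort(unique(ytrain))
    centroids = [vec(mean(xtrain[ytrain .== l, :]; dims=1)) for l in labels]
    return [labels[argmin([sum(abs2, row .- c) for c in centroids])]
            for row in eachrow(xtest)]
end

# Sketch of the deterministic R* score: K times the held-out accuracy.
function rstar_sketch(rng, x, y; subset=0.8)
    N = length(y)
    K = length(unique(y))
    Ntrain = round(Int, N * subset)
    ids = randperm(rng, N)
    train_ids, test_ids = ids[1:Ntrain], ids[(Ntrain + 1):N]
    pred = centroid_predict(x[train_ids, :], y[train_ids], x[test_ids, :])
    return K * mean(pred .== y[test_ids])
end

rng = MersenneTwister(1)
# Two "chains" stuck in well-separated regions: the labels are easy to recover.
x = vcat(randn(rng, 500, 2), randn(rng, 500, 2) .+ 10)
y = repeat(1:2; inner=500)
r_separated = rstar_sketch(rng, x, y)
```

With chains this far apart the accuracy should be near one, so the score should land near K = 2; for well-mixed chains the accuracy approaches 1/K and the score approaches 1.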
5 changes: 3 additions & 2 deletions test/diagnostic_tests.jl
@@ -82,7 +82,7 @@ end
end

@testset "function tests" begin
-tchain = Chains(rand(n_iter, n_name, n_chain), ["a", "b", "c"], Dict(:internals => ["c"]))
+tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"]))

# the following tests only check if the function calls work!
@test MCMCChains.diag_all(rand(50, 2), :weiss, 1, 1, 1) != nothing
@@ -137,9 +137,10 @@ end
end

@testset "sorting" begin
-chn_unsorted = Chains(rand(100,3,1), ["2", "1", "3"])
+chn_unsorted = Chains(rand(100, nparams, 1), ["2", "1", "3"])
chn_sorted = sort(chn_unsorted)

@test names(chn_sorted) == Symbol.([1, 2, 3])
@test names(chn_unsorted) == Symbol.([2, 1, 3])
end

37 changes: 37 additions & 0 deletions test/rstar_tests.jl
@@ -0,0 +1,37 @@
using MCMCChains
using Tables
using MLJ, MLJModels
using Test

N = 1000
val = rand(N, 8, 4)
colnames = ["a", "b", "c", "d", "e", "f", "g", "h"]
internal_colnames = ["c", "d", "e", "f", "g", "h"]
chn = Chains(val, colnames, Dict(:internals => internal_colnames))

classif = @load XGBoostClassifier()

@testset "R star test" begin

# Compute R* statistic for a mixed chain.
R = rstar(classif, randn(N,2), rand(1:3,N))

# Resulting R value should be close to one, i.e. the classifier does not perform better than random guessing.
@test mean(R) ≈ 1 atol=0.1

# Compute R* statistic for a mixed chain.
R = rstar(classif, chn)

# Resulting R value should be close to one, i.e. the classifier does not perform better than random guessing.
@test mean(R) ≈ 1 atol=0.1

# Compute R* statistic for a non-mixed chain.
niter = 1000
val = hcat(sin.(1:niter), cos.(1:niter))
val = cat(val, hcat(cos.(1:niter)*100, sin.(1:niter)*100), dims=3)
chn_notmixed = Chains(val)

# Resulting R value should be close to two, i.e. the classifier should be able to learn an almost perfect decision boundary between chains.
R = rstar(classif, chn_notmixed)
@test mean(R) ≈ 2 atol=0.1
end
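
The tests call `rstar` both with a plain `(x, y)` pair and with a `Chains` object. A rough sketch of how a 3-D sample array could be flattened into such a pair is shown below, assuming chains are stacked vertically in chain order (which is what the `repeat(...; inner = ...)` labelling in `src/rstar.jl` implies); the variable names are illustrative only.

```julia
# Toy sample array: iterations × parameters × chains.
val = reshape(collect(1.0:24.0), 4, 2, 3)
niter, nparams, nchains = size(val)

# Stack the chains vertically: rows 1:niter belong to chain 1, and so on.
x = reduce(vcat, [val[:, :, k] for k in 1:nchains])

# One integer label per row, repeated `niter` times per chain.
y = repeat(1:nchains; inner = niter)
```

Each row of `x` is one draw and the matching entry of `y` is the chain it came from, which is the classification target the diagnostic trains on.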
8 changes: 7 additions & 1 deletion test/runtests.jl
@@ -1,8 +1,14 @@
using Test

@testset "MCMCChains" begin

# run tests related to rstar statistic
println("Rstar")
@time include("rstar_tests.jl")

# run tests for effective sample size
-include("ess_tests.jl")
+println("ESS")
+@time include("ess_tests.jl")

# run plotting tests
println("Plotting")

2 comments on commit 4da661d

@cpfiffer
Member

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/21527

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.2.0 -m "<description of version>" 4da661dd60b396e6166be8c249e59f6037d3d556
git push origin v4.2.0
