Adding implementation of Rstar statistic (#238)
* Added implementation of Rstar statistic
* extracted method to work with arrays instead, added missing import
* Update src/rstar.jl (Co-authored-by: David Widmann <[email protected]>)
* Update src/rstar.jl (Co-authored-by: David Widmann <[email protected]>)
* Update src/rstar.jl (Co-authored-by: David Widmann <[email protected]>)
* Added more tests, minor changes and fixed Algo. 2
* Update src/rstar.jl (Co-authored-by: David Widmann <[email protected]>)
* Changed code to use MLJ interface.
* fixed bug
* Update src/rstar.jl (Co-authored-by: David Widmann <[email protected]>)
* minor code improvements, added README entry
* added Ref
* Update src/MCMCChains.jl (Co-authored-by: David Widmann <[email protected]>)
* Update src/rstar.jl (Co-authored-by: David Widmann <[email protected]>)
* Update src/rstar.jl (Co-authored-by: David Widmann <[email protected]>)
* Update src/rstar.jl (Co-authored-by: David Widmann <[email protected]>)
* final updates
* Update Project.toml (Co-authored-by: David Widmann <[email protected]>)
* unified test (Co-authored-by: David Widmann <[email protected]>)
1 parent c560033, commit 4da661d
Showing 7 changed files with 180 additions and 6 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
rstar([rng ,] classif::Supervised, chains::Chains; kwargs...) | ||
rstar([rng ,] classif::Supervised, x::AbstractMatrix, y::AbstractVector; kwargs...) | ||
Compute the R* convergence diagnostic of MCMC. | ||
This implementation is an adaptation of Algorithms 1 and 2 described in [Lambert & Vehtari]. Note that the correctness of the statistic depends on the convergence of the classifier used internally in the statistic. You can check whether the training of the classifier converged by inspecting the printed RMSE values from the XGBoost backend. To adjust the number of iterations used to train the classifier, set `niter` accordingly. | ||
# Keyword Arguments | ||
* `subset = 0.8` ... Subset used to train the classifier, i.e. 0.8 implies 80% of the samples are used. | ||
* `iterations = 10` ... Number of iterations used to estimate the statistic. If the classifier is not probabilistic, i.e. does not return class probabilities, it is advisable to use a value of one. | ||
* `verbosity = 0` ... Verbosity level used during fitting of the classifier. | ||
# Usage | ||
```julia | ||
using MLJ, MLJModels | ||
# You need to load MLJBase and the respective package you are using for classification first. | ||
# Select a classifier to compute the Rstar statistic. | ||
# For example the XGBoost classifier. | ||
classif = @load XGBoostClassifier() | ||
# Compute 20 samples of the R* statistic by sampling according to the prediction probabilities. | ||
Rs = rstar(classif, chn, iterations = 20) | ||
# estimate Rstar | ||
R = mean(Rs) | ||
# visualize distribution | ||
histogram(Rs) | ||
``` | ||
## References: | ||
[Lambert & Vehtari] Ben Lambert and Aki Vehtari. "R∗: A robust MCMC convergence diagnostic with uncertainty using gradient-boosted machines." arXiv 2020. | ||
""" | ||
function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; iterations = 10, subset = 0.8, verbosity = 0) | ||
|
||
size(x,1) != length(y) && throw(DimensionMismatch()) | ||
iterations >= 1 || throw(ArgumentError("Number of iterations has to be positive!")) | ||
|
||
if iterations > 1 && classif isa MLJModelInterface.Deterministic | ||
@warn("Classifier is not a probabilistic classifier but number of iterations is > 1.") | ||
elseif iterations == 1 && classif isa MLJModelInterface.Probabilistic | ||
@warn("Classifier is probabilistic but number of iterations is equal to one.") | ||
end | ||
|
||
N = length(y) | ||
K = length(unique(y)) | ||
|
||
# randomly sub-select training and testing set | ||
Ntrain = round(Int, N*subset) | ||
Ntest = N - Ntrain | ||
|
||
ids = Random.randperm(rng, N) | ||
train_ids = view(ids, 1:Ntrain) | ||
test_ids = view(ids, (Ntrain+1):N) | ||
|
||
# train the supplied classifier on the training subset | ||
fitresult, _ = MLJModelInterface.fit(classif, verbosity, Tables.table(x[train_ids,:]), MLJModelInterface.categorical(y[train_ids])) | ||
|
||
xtest = Tables.table(x[test_ids,:]) | ||
ytest = view(y, test_ids) | ||
|
||
Rstats = map(i -> K*rstar_score(rng, classif, fitresult, xtest, ytest), 1:iterations) | ||
return Rstats | ||
end | ||
|
||
function rstar(classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; kwargs...) | ||
rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...) | ||
end | ||
|
||
function rstar(classif::MLJModelInterface.Supervised, chn::Chains; kwargs...) | ||
return rstar(Random.GLOBAL_RNG, classif, chn; kwargs...) | ||
end | ||
|
||
function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, chn::Chains; kwargs...) | ||
nchains = size(chn, 3) | ||
nchains <= 1 && throw(DimensionMismatch()) | ||
|
||
# collect data | ||
x = Array(chn) | ||
y = repeat(chains(chn); inner = size(chn,1)) | ||
|
||
return rstar(rng, classif, x, y; kwargs...) | ||
end | ||
|
||
function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Probabilistic, fitresult, xtest, ytest) | ||
pred = get.(rand.(Ref(rng), MLJModelInterface.predict(classif, fitresult, xtest))) | ||
return mean(((p,y),) -> p == y, zip(pred, ytest)) | ||
end | ||
|
||
function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Deterministic, fitresult, xtest, ytest) | ||
pred = MLJModelInterface.predict(classif, fitresult, xtest) | ||
return mean(((p,y),) -> p == y, zip(pred, ytest)) | ||
end |
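The probabilistic `rstar_score` method above draws one label per test sample from the predicted class probabilities and returns the accuracy, which `rstar` then multiplies by `K`. The following is a minimal, dependency-free sketch of that idea, not the code from this commit: the probability matrices are made up for illustration, and `sample_label` and `rstar_draw` are hypothetical helper names. It also takes `K` from the number of probability columns, whereas the commit uses `length(unique(y))`.

```julia
using Random, Statistics

# Draw one label from a vector of class probabilities (inverse-CDF sampling).
# Hypothetical helper; the commit uses `rand` on MLJ's predicted distributions instead.
function sample_label(rng::AbstractRNG, probs::AbstractVector{<:Real})
    u = rand(rng)
    c = 0.0
    for (k, p) in enumerate(probs)
        c += p
        u < c && return k
    end
    return length(probs)  # guard against round-off when probs sum to slightly less than 1
end

# One R* draw: sample a label per test row, then scale the accuracy by K.
# Assumption: K is the number of columns of `probs`.
function rstar_draw(rng::AbstractRNG, probs::AbstractMatrix{<:Real}, ytest::AbstractVector{Int})
    K = size(probs, 2)
    pred = [sample_label(rng, view(probs, i, :)) for i in axes(probs, 1)]
    return K * mean(pred .== ytest)
end

rng = MersenneTwister(42)
ytest = repeat(1:2; inner = 200)

# Uninformative predictions: the best a classifier can do on well-mixed chains.
uniform_probs = fill(0.5, 400, 2)
Rs_mixed = [rstar_draw(rng, uniform_probs, ytest) for _ in 1:20]

# Confident, correct predictions: chains that are trivially distinguishable.
sharp_probs = [k == ytest[i] ? 1.0 : 0.0 for i in 1:400, k in 1:2]
Rs_separated = [rstar_draw(rng, sharp_probs, ytest) for _ in 1:20]
```

Repeating the draw, as the `iterations` keyword does, yields a sample of R* values whose spread reflects the uncertainty in the classifier's predictions. With uniform probabilities the draws scatter around 1, and with fully confident correct predictions every draw equals `K`.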
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
using MCMCChains | ||
using Tables | ||
using MLJ, MLJModels | ||
using Test | ||
|
||
N = 1000 | ||
val = rand(N, 8, 4) | ||
colnames = ["a", "b", "c", "d", "e", "f", "g", "h"] | ||
internal_colnames = ["c", "d", "e", "f", "g", "h"] | ||
chn = Chains(val, colnames, Dict(:internals => internal_colnames)) | ||
|
||
classif = @load XGBoostClassifier() | ||
|
||
@testset "R star test" begin | ||
|
||
# Compute R* statistic for a mixed chain. | ||
R = rstar(classif, randn(N,2), rand(1:3,N)) | ||
|
||
# Resulting R value should be close to one, i.e. the classifier does not perform better than random guessing. | ||
@test mean(R) ≈ 1 atol=0.1 | ||
|
||
# Compute R* statistic for a mixed chain. | ||
R = rstar(classif, chn) | ||
|
||
# Resulting R value should be close to one, i.e. the classifier does not perform better than random guessing. | ||
@test mean(R) ≈ 1 atol=0.1 | ||
|
||
# Compute R* statistic for a non-mixed chain. | ||
niter = 1000 | ||
val = hcat(sin.(1:niter), cos.(1:niter)) | ||
val = cat(val, hcat(cos.(1:niter)*100, sin.(1:niter)*100), dims=3) | ||
chn_notmixed = Chains(val) | ||
|
||
# Resulting R value should be close to two, i.e. the classifier should be able to learn an almost perfect decision boundary between chains. | ||
R = rstar(classif, chn_notmixed) | ||
@test mean(R) ≈ 2 atol=0.1 | ||
end |
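The expected values in these tests follow from R* being `K` times the held-out accuracy: a classifier that can only guess reaches accuracy near `1/K`, giving R* near 1, while one that separates the chains perfectly gives R* near `K`. The sketch below illustrates this without MLJ or XGBoost. It assumes a simple nearest-centroid rule as a stand-in classifier, and `fit_centroids`, `predict_centroids` and `rstar_sketch` are hypothetical names that are not part of MCMCChains.

```julia
using Random, Statistics

# Hypothetical stand-in for an MLJ classifier: assign each row to the label
# whose training mean (centroid) is closest.
function fit_centroids(x::AbstractMatrix, y::AbstractVector{Int})
    labels = sort(unique(y))
    centroids = [vec(mean(x[y .== k, :]; dims = 1)) for k in labels]
    return labels, centroids
end

function predict_centroids((labels, centroids), x::AbstractMatrix)
    return map(eachrow(x)) do row
        labels[argmin([sum(abs2, row .- c) for c in centroids])]
    end
end

# R* for a deterministic classifier: K times the accuracy on the held-out rows.
function rstar_sketch(rng::AbstractRNG, x::AbstractMatrix, y::AbstractVector{Int}; subset = 0.8)
    N = length(y)
    K = length(unique(y))
    ids = randperm(rng, N)
    Ntrain = round(Int, N * subset)
    train_ids, test_ids = ids[1:Ntrain], ids[(Ntrain + 1):N]
    model = fit_centroids(x[train_ids, :], y[train_ids])
    pred = predict_centroids(model, x[test_ids, :])
    return K * mean(pred .== y[test_ids])
end

rng = MersenneTwister(1)
y = repeat(1:2; inner = 500)

# Two well-separated "chains": the classifier can tell them apart, so R* is near K = 2.
x_sep = vcat(randn(rng, 500, 2), randn(rng, 500, 2) .+ 100)
r_sep = rstar_sketch(rng, x_sep, y)

# Two identically distributed "chains": accuracy is near 1/K, so R* is near 1.
x_mix = randn(rng, 1000, 2)
r_mix = rstar_sketch(rng, x_mix, y)
```

A nearest-centroid rule is far weaker than gradient-boosted trees, so this only illustrates the two limiting cases that the tests check, not the diagnostic's sensitivity in between.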
@JuliaRegistrator register
Registration pull request created: JuliaRegistries/General/21527
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: