-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add compatibility with MCMCDiagnosticTools v0.3 (#401)
* Bump MCMCDiagnosticTools compat * Update imported/exported methods * Remove type constraint on classifier * Overload and export mcse * Overload and update ess and rhat * Update summarystats * Update tests * Increment major version * Rename ess.jl to ess_rhat.jl * Add back ess_per_sec * Fix bug constructing ess_per_sec * Update ess_rhat tests * Test mcse * Update docs * Remove deprecations * Remove unused import * Revert "Fix MLJDecisionTreeInterface to 0.3.0 (#402)" This reverts commit 991f10b. * Always include ess_per_sec in table * Use isequal to pass with missing values * Use isequal for missing * Remove naive_se Fixes #351 * Test Tables interface before loading StatsPlots DataValues (a StatsPlots dependency) pirates a convert method that causes the Tables equality tests with `missing` to fail. See https://github.com/queryverse/DataValues.jl * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: David Widmann <[email protected]>
- Loading branch information
Showing
16 changed files
with
268 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ Pages = [ | |
"heideldiag.jl", | ||
"rafterydiag.jl", | ||
"rstar.jl", | ||
"ess.jl" | ||
"ess_rhat.jl", | ||
"mcse.jl", | ||
] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
ess(chains::Chains; duration=compute_duration, kwargs...) | ||
Estimate the effective sample size. | ||
ESS per second options include `duration=MCMCChains.compute_duration` (the default) | ||
and `duration=MCMCChains.wall_duration`. | ||
""" | ||
function MCMCDiagnosticTools.ess( | ||
chains::Chains; | ||
sections = _default_sections(chains), duration = compute_duration, kwargs... | ||
) | ||
# Subset the chain | ||
_chains = Chains(chains, _clean_sections(chains, sections)) | ||
|
||
# Estimate the effective sample size | ||
ess = MCMCDiagnosticTools.ess( | ||
_permutedims_diagnostics(_chains.value.data); | ||
kwargs..., | ||
) | ||
|
||
# Calculate ESS/minute if available | ||
dur = duration(chains) | ||
|
||
# Convert to NamedTuple | ||
ess_per_sec = ess ./ dur | ||
nt = merge((parameters = names(_chains),), (; ess, ess_per_sec)) | ||
|
||
return ChainDataFrame("ESS", nt) | ||
end | ||
|
||
""" | ||
rhat(chains::Chains; kwargs...) | ||
Estimate the ``\\widehat{R}`` diagnostic. | ||
""" | ||
function MCMCDiagnosticTools.rhat( | ||
chains::Chains; | ||
sections = _default_sections(chains), kwargs... | ||
) | ||
# Subset the chain | ||
_chains = Chains(chains, _clean_sections(chains, sections)) | ||
|
||
# Estimate the rhat | ||
rhat = MCMCDiagnosticTools.rhat( | ||
_permutedims_diagnostics(_chains.value.data); | ||
kwargs..., | ||
) | ||
|
||
# Convert to NamedTuple | ||
nt = merge((parameters = names(_chains),), (; rhat)) | ||
|
||
return ChainDataFrame("R-hat", nt) | ||
end | ||
|
||
""" | ||
ess_rhat(chains::Chains; duration=compute_duration, kwargs...) | ||
Estimate the effective sample size and the ``\\widehat{R}`` diagnostic | ||
ESS per second options include `duration=MCMCChains.compute_duration` (the default) | ||
and `duration=MCMCChains.wall_duration`. | ||
""" | ||
function MCMCDiagnosticTools.ess_rhat( | ||
chains::Chains; | ||
sections = _default_sections(chains), duration = compute_duration, kwargs... | ||
) | ||
# Subset the chain | ||
_chains = Chains(chains, _clean_sections(chains, sections)) | ||
|
||
# Estimate the effective sample size and rhat | ||
ess_rhat = MCMCDiagnosticTools.ess_rhat( | ||
_permutedims_diagnostics(_chains.value.data); | ||
kwargs..., | ||
) | ||
|
||
# Calculate ESS/minute if available | ||
dur = duration(chains) | ||
|
||
# Convert to NamedTuple | ||
ess_per_sec = ess_rhat.ess ./ dur | ||
nt = merge((parameters = names(_chains),), ess_rhat, (; ess_per_sec)) | ||
|
||
return ChainDataFrame("ESS/R-hat", nt) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
mcse(chains::Chains; duration=compute_duration, kwargs...) | ||
Estimate the Monte Carlo standard error. | ||
""" | ||
function MCMCDiagnosticTools.mcse( | ||
chains::Chains; | ||
sections = _default_sections(chains), kwargs... | ||
) | ||
# Subset the chain | ||
_chains = Chains(chains, _clean_sections(chains, sections)) | ||
|
||
# Estimate the effective sample size | ||
mcse = MCMCDiagnosticTools.mcse( | ||
_permutedims_diagnostics(_chains.value.data); | ||
kwargs..., | ||
) | ||
|
||
nt = merge((parameters = names(_chains),), (; mcse)) | ||
|
||
return ChainDataFrame("MCSE", nt) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
ddac60f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
ddac60f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/78718
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: