From dfbf1c918271f1ef3e46ca30754466a8b5b37b9a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 3 Oct 2024 15:54:40 +0200 Subject: [PATCH 1/3] Use `n_adapts` instead of `nadapts` --- README.md | 2 +- src/abstractmcmc.jl | 17 ++++++++++++++--- test/abstractmcmc.jl | 24 ++++++++++++++++++++++++ test/mcmcchains.jl | 2 +- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index c3a9bf93..f16830b7 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ samples = AbstractMCMC.sample( model, sampler, n_adapts + n_samples; - nadapts = n_adapts, + n_adapts = n_adapts, initial_params = initial_θ, ) ``` diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 96ec9c7a..1b8eea19 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -47,6 +47,10 @@ function AbstractMCMC.sample( callback = nothing, kwargs..., ) + if haskey(kwargs, :nadapts) + throw(ArgumentError("keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.")) + end + if callback === nothing callback = HMCProgressCallback(N, progress = progress, verbose = verbose) progress = false # don't use AMCMC's progress-funtionality @@ -78,6 +82,9 @@ function AbstractMCMC.sample( callback = nothing, kwargs..., ) + if haskey(kwargs, :nadapts) + throw(ArgumentError("keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.")) + end if callback === nothing callback = HMCProgressCallback(N, progress = progress, verbose = verbose) @@ -144,6 +151,10 @@ function AbstractMCMC.step( n_adapts::Int = 0, kwargs..., ) + if haskey(kwargs, :nadapts) + throw(ArgumentError("keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.")) + end + # Compute transition. i = state.i + 1 t_old = state.transition @@ -200,7 +211,7 @@ function HMCProgressCallback(n_samples; progress = true, verbose = false) HMCProgressCallback(pm, progress, verbose, Ref(0), Ref(0)) end -function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kwargs...) +function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; n_adapts::Int = 0, kwargs...) progress = cb.progress verbose = cb.verbose pm = cb.pm @@ -243,8 +254,8 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw ), ) # Report finish of adapation - elseif verbose && isadapted && i == nadapts - @info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric + elseif verbose && isadapted && i == n_adapts + @info "Finished $(n_adapts) adapation steps" adaptor κ.τ.integrator metric end end diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index 242c9f29..25359cd6 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -32,6 +32,30 @@ using Statistics: mean verbose = false, ) + # Error if keyword argument `nadapts` is used + @test_throws ArgumentError AbstractMCMC.sample( + rng, + model, + nuts, + n_adapts + n_samples; + nadapts = n_adapts, + initial_params = θ_init, + progress = false, + verbose = false, + ) + @test_throws ArgumentError AbstractMCMC.sample( + rng, + model, + nuts, + MCMCThreads(), + n_adapts + n_samples, + 2; + nadapts = n_adapts, + initial_params = θ_init, + progress = false, + verbose = false, + ) + # Transform back to original space. # NOTE: We're not correcting for the `logabsdetjac` here since, but # we're only interested in the mean it doesn't matter. diff --git a/test/mcmcchains.jl b/test/mcmcchains.jl index da470842..1f868578 100644 --- a/test/mcmcchains.jl +++ b/test/mcmcchains.jl @@ -23,7 +23,7 @@ using Statistics: mean model, sampler, n_adapts + n_samples; - nadapts = n_adapts, + n_adapts = n_adapts, initial_params = θ_init, chain_type = Chains, progress = false, From ff71b1fce1a18ac75c7a87fabcbeb3021ffaa87e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 3 Oct 2024 16:44:02 +0200 Subject: [PATCH 2/3] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index efdaf2a4..2d471834 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.6.1" +version = "0.6.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From a186dbb6e7b6e1cdd508e84b32453867f407337a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 3 Oct 2024 18:21:46 +0200 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/abstractmcmc.jl | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 1b8eea19..302d927e 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -48,7 +48,11 @@ function AbstractMCMC.sample( kwargs..., ) if haskey(kwargs, :nadapts) - throw(ArgumentError("keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.")) + throw( + ArgumentError( + "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.", + ), + ) end if callback === nothing @@ -83,7 +87,11 @@ function AbstractMCMC.sample( kwargs..., ) if haskey(kwargs, :nadapts) - throw(ArgumentError("keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.")) + throw( + ArgumentError( + "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.", + ), + ) end if callback === nothing @@ -152,7 +160,11 @@ function AbstractMCMC.step( kwargs..., ) if haskey(kwargs, :nadapts) - throw(ArgumentError("keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.")) + throw( + ArgumentError( + "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.", + ), + ) end # Compute transition. @@ -211,7 +223,16 @@ function HMCProgressCallback(n_samples; progress = true, verbose = false) HMCProgressCallback(pm, progress, verbose, Ref(0), Ref(0)) end -function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; n_adapts::Int = 0, kwargs...) +function (cb::HMCProgressCallback)( + rng, + model, + spl, + t, + state, + i; + n_adapts::Int = 0, + kwargs..., +) progress = cb.progress verbose = cb.verbose pm = cb.pm