diff --git a/Project.toml b/Project.toml index dc5f9e02..a3c51910 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.6.1" +version = "0.6.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/README.md b/README.md index c3a9bf93..f16830b7 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ samples = AbstractMCMC.sample( model, sampler, n_adapts + n_samples; - nadapts = n_adapts, + n_adapts = n_adapts, initial_params = initial_θ, ) ``` diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 502763a8..24ce799c 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -47,6 +47,14 @@ function AbstractMCMC.sample( callback = nothing, kwargs..., ) + if haskey(kwargs, :nadapts) + throw( + ArgumentError( + "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.", + ), + ) + end + if callback === nothing callback = HMCProgressCallback(N, progress = progress, verbose = verbose) progress = false # don't use AMCMC's progress-funtionality @@ -78,6 +86,13 @@ function AbstractMCMC.sample( callback = nothing, kwargs..., ) + if haskey(kwargs, :nadapts) + throw( + ArgumentError( + "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.", + ), + ) + end if callback === nothing callback = HMCProgressCallback(N, progress = progress, verbose = verbose) @@ -144,6 +159,14 @@ function AbstractMCMC.step( n_adapts::Int = 0, kwargs..., ) + if haskey(kwargs, :nadapts) + throw( + ArgumentError( + "keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.", + ), + ) + end + # Compute transition. i = state.i + 1 t_old = state.transition @@ -200,7 +223,16 @@ function HMCProgressCallback(n_samples; progress = true, verbose = false) HMCProgressCallback(pm, progress, verbose, Ref(0), Ref(0)) end -function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kwargs...) +function (cb::HMCProgressCallback)( + rng, + model, + spl, + t, + state, + i; + n_adapts::Int = 0, + kwargs..., +) progress = cb.progress verbose = cb.verbose pm = cb.pm @@ -243,8 +275,8 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw ), ) # Report finish of adapation - elseif verbose && isadapted && i == nadapts - @info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric + elseif verbose && isadapted && i == n_adapts + @info "Finished $(n_adapts) adapation steps" adaptor κ.τ.integrator metric end end diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index 242c9f29..25359cd6 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -32,6 +32,30 @@ using Statistics: mean verbose = false, ) + # Error if keyword argument `nadapts` is used + @test_throws ArgumentError AbstractMCMC.sample( + rng, + model, + nuts, + n_adapts + n_samples; + nadapts = n_adapts, + initial_params = θ_init, + progress = false, + verbose = false, + ) + @test_throws ArgumentError AbstractMCMC.sample( + rng, + model, + nuts, + MCMCThreads(), + n_adapts + n_samples, + 2; + nadapts = n_adapts, + initial_params = θ_init, + progress = false, + verbose = false, + ) + # Transform back to original space. # NOTE: We're not correcting for the `logabsdetjac` here since, but # we're only interested in the mean it doesn't matter. diff --git a/test/mcmcchains.jl b/test/mcmcchains.jl index da470842..1f868578 100644 --- a/test/mcmcchains.jl +++ b/test/mcmcchains.jl @@ -23,7 +23,7 @@ using Statistics: mean model, sampler, n_adapts + n_samples; - nadapts = n_adapts, + n_adapts = n_adapts, initial_params = θ_init, chain_type = Chains, progress = false,