From a883eb673e37c7700774dbe540a3a8f441fea9f7 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 31 Oct 2023 21:48:36 +0000 Subject: [PATCH] Fix some tests. (#356) * CompatHelper: add new compat entry for Statistics at version 1, (keep existing compat) (#354) Co-authored-by: CompatHelper Julia * Update constructors.jl * Update constructors.jl * Update test/constructors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Update abstractmcmc.jl --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: CompatHelper Julia --- Project.toml | 1 + src/abstractmcmc.jl | 2 +- test/constructors.jl | 22 +++++++++++++--------- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index afc032c9..4fe85fe9 100644 --- a/Project.toml +++ b/Project.toml @@ -43,6 +43,7 @@ ProgressMeter = "1" Requires = "0.5, 1" Setfield = "0.7, 0.8, 1" SimpleUnPack = "1.1" +Statistics = "1.6" StatsBase = "0.31, 0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" julia = "1.6" diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index f52a90ec..4e94b6d4 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -117,7 +117,7 @@ function AbstractMCMC.step( # Define integration algorithm # Find good eps if not provided one - initial_params = make_init_params(rng, spl, logdensity, initial_params) + initial_params = make_initial_params(rng, spl, logdensity, initial_params) ϵ = make_step_size(rng, spl, hamiltonian, initial_params) integrator = make_integrator(spl, ϵ) diff --git a/test/constructors.jl b/test/constructors.jl index 49b0cf89..d490eb11 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -131,9 +131,13 @@ get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_crite @test AdvancedHMC.sampler_eltype(sampler) == T # Step. - transition, state = - AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init) - + transition, state = AbstractMCMC.step( + rng, + model, + sampler; + n_adapts = 0, + initial_params = θ_init, + ) # Verify that the types are preserved in the transition. @test eltype(transition.z.θ) == T @test eltype(transition.z.r) == T @@ -159,7 +163,7 @@ get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_crite end @testset "Utils" begin - @testset "init_params" begin + @testset "initial_params" begin d = 2 θ_init = randn(d) rng = Random.default_rng() @@ -171,10 +175,10 @@ end metric = AdvancedHMC.make_metric(spl, logdensity) hamiltonian = Hamiltonian(metric, model) - init_params1 = AdvancedHMC.make_init_params(rng, spl, logdensity, nothing) - @test typeof(init_params1) == Vector{T} - @test length(init_params1) == d - init_params2 = AdvancedHMC.make_init_params(rng, spl, logdensity, θ_init) - @test init_params2 == θ_init + initial_params1 = AdvancedHMC.make_initial_params(rng, spl, logdensity, nothing) + @test typeof(initial_params1) == Vector{T} + @test length(initial_params1) == d + initial_params2 = AdvancedHMC.make_initial_params(rng, spl, logdensity, θ_init) + @test initial_params2 == θ_init end end