Skip to content

Commit

Permalink
Fix some tests. (#356)
Browse files Browse the repository at this point in the history
* CompatHelper: add new compat entry for Statistics at version 1, (keep existing compat) (#354)

Co-authored-by: CompatHelper Julia <[email protected]>

* Update constructors.jl

* Update constructors.jl

* Update test/constructors.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update Project.toml

* Update abstractmcmc.jl

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: CompatHelper Julia <[email protected]>
  • Loading branch information
3 people authored Oct 31, 2023
1 parent 626c71d commit a883eb6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ ProgressMeter = "1"
Requires = "0.5, 1"
Setfield = "0.7, 0.8, 1"
SimpleUnPack = "1.1"
Statistics = "1.6"
StatsBase = "0.31, 0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
julia = "1.6"
Expand Down
2 changes: 1 addition & 1 deletion src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function AbstractMCMC.step(

# Define integration algorithm
# Find good eps if not provided one
initial_params = make_init_params(rng, spl, logdensity, initial_params)
initial_params = make_initial_params(rng, spl, logdensity, initial_params)
ϵ = make_step_size(rng, spl, hamiltonian, initial_params)
integrator = make_integrator(spl, ϵ)

Expand Down
22 changes: 13 additions & 9 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,13 @@ get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_crite
@test AdvancedHMC.sampler_eltype(sampler) == T

# Step.
transition, state =
AbstractMCMC.step(rng, model, sampler; n_adapts = 0, init_params = θ_init)

transition, state = AbstractMCMC.step(
rng,
model,
sampler;
n_adapts = 0,
initial_params = θ_init,
)
# Verify that the types are preserved in the transition.
@test eltype(transition.z.θ) == T
@test eltype(transition.z.r) == T
Expand All @@ -159,7 +163,7 @@ get_kernel_hyperparamsT(spl::NUTS, state) = typeof(state.κ.τ.termination_crite
end

@testset "Utils" begin
@testset "init_params" begin
@testset "initial_params" begin
d = 2
θ_init = randn(d)
rng = Random.default_rng()
Expand All @@ -171,10 +175,10 @@ end
metric = AdvancedHMC.make_metric(spl, logdensity)
hamiltonian = Hamiltonian(metric, model)

init_params1 = AdvancedHMC.make_init_params(rng, spl, logdensity, nothing)
@test typeof(init_params1) == Vector{T}
@test length(init_params1) == d
init_params2 = AdvancedHMC.make_init_params(rng, spl, logdensity, θ_init)
@test init_params2 == θ_init
initial_params1 = AdvancedHMC.make_initial_params(rng, spl, logdensity, nothing)
@test typeof(initial_params1) == Vector{T}
@test length(initial_params1) == d
initial_params2 = AdvancedHMC.make_initial_params(rng, spl, logdensity, θ_init)
@test initial_params2 == θ_init
end
end

0 comments on commit a883eb6

Please sign in to comment.