Skip to content

Commit

Permalink
even closer but Hyperparams are defined weirdly
Browse files Browse the repository at this point in the history
  • Loading branch information
JaimeRZP committed Feb 19, 2024
1 parent 3769e99 commit 8e79583
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
13 changes: 7 additions & 6 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ mutable struct Hyperparameters{T}
sigma_xi::T
end

function Hyperparameters(eps::T, L::T, sigma::Vector{T}; kwargs...) where {T}
Hyperparameters(T::Type; kwargs...) = begin
eps = get(kwargs, :eps, T(0.0))
L = get(kwargs, :L, T(0.0))
nu = get(kwargs, :nu, T(0.0))
sigma = get(kwargs, :sigma, [T(0.0)])
lambda_c = get(kwargs, :lambda_c, T(0.1931833275037836))
gamma = get(kwargs, :gamma, T((50 - 1) / (50 + 1))) #(neff-1)/(neff+1)
sigma_xi = get(kwargs, :sigma_xi, T(1.5))
return Hyperparameters(eps, L, nu, lambda_c, sigma, gamma, sigma_xi)
Hyperparameters(eps, L, nu, lambda_c, sigma, gamma, sigma_xi)
end

struct MCHMCSampler <: AbstractMCMC.AbstractSampler
Expand All @@ -34,11 +37,9 @@ end
Constructor for the MicroCanonical HMC sampler
"""
function MCHMC(nadapt::Int, TEV::Real;
eps=0.0, L=0.0, sigma=[0.0],
integrator="LF", adaptive=false, kwargs...)
function MCHMC(nadapt::Int, TEV::Real; integrator="LF", adaptive=false, T::Type=Float64, kwargs...)
"""the MCHMC (q = 0 Hamiltonian) sampler"""
hyperparameters = Hyperparameters(eps, L, sigma; kwargs...)
hyperparameters = Hyperparameters(T; kwargs...)

### integrator ###
if integrator == "LF" # leapfrog
Expand Down
15 changes: 10 additions & 5 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@
@test eltype(target_default.h.ℓπ(target_default.θ_start)) == Float64
@test eltype(target_T.h.ℓπ(target_T.θ_start)) == T

spl = MCHMC(0, 0.01)
spl = MCHMC(0, 0.01; T=T)
@test eltype(spl.hyperparameters.eps) == Float32
_, _, _, = MicroCanonicalHMC.tune_what(spl, 10; T=Float32)
#println(spl.hyperparameters)
#@test eltype(spl.hyperparameters.eps) == Float32
#@test eltype(spl.hyperparameters.L) == Float32
#@test eltype(spl.hyperparameters.sigma) == Float32
@test eltype(spl.hyperparameters.eps) == Float32
t, s = Step(spl, target_T.h, target_T.θ_start)
@test eltype(s.x) == T

aspl = MCHMC(0, 0.01; T=T, adaptive=true)
@test eltype(aspl.hyperparameters.eps) == Float32
_, _, _, = MicroCanonicalHMC.tune_what(spl, 10; T=Float32)
@test eltype(aspl.hyperparameters.eps) == Float32
t, s = Step(aspl, target_T.h, target_T.θ_start)
@test eltype(s.x) == T
end
@testset "Settings" begin
spl = MCHMC(
Expand Down

0 comments on commit 8e79583

Please sign in to comment.