Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JaimeRZP committed Jan 14, 2025
1 parent fc1b75e commit fffb69e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
41 changes: 32 additions & 9 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mutable struct MCHMCSampler <: AbstractMCMC.AbstractSampler
tune_eps::Bool
tune_L::Bool
tune_sigma::Bool
L_tuning_method::String
hyperparameters::Hyperparameters
hamiltonian_dynamics::Function
end
Expand All @@ -43,6 +44,7 @@ function MCHMC(nadapt::Int, TEV::Real;
tune_eps=true,
tune_L=true,
tune_sigma=true,
L_tuning_method="sigma",
kwargs...)

### Init Hyperparameters ###
Expand All @@ -57,7 +59,17 @@ function MCHMC(nadapt::Int, TEV::Real;
println(string("integrator = ", integrator, "is not a valid option."))
end

return MCHMCSampler(nadapt, TEV, adaptive, tune_eps, tune_L, tune_sigma, hyperparameters, hamiltonian_dynamics)
return MCHMCSampler(
nadapt,
TEV,
adaptive,
tune_eps,
tune_L,
tune_sigma,
L_tuning_method,
hyperparameters,
hamiltonian_dynamics,
)
end

function Random_unit_vector(rng::AbstractRNG, x::AbstractVector{T}; _normalize = true) where {T}
Expand Down Expand Up @@ -239,9 +251,11 @@ function Sample(
n::Int;
thinning::Int=1,
init_params = nothing,
init_state = nothing,
file_chunk=10,
fol_name = ".",
file_name = "samples",
restart = false,
include_latent = false,
kwargs...,
)
Expand All @@ -258,13 +272,18 @@ function Sample(
end
x_start = target.transform(θ_start)

transition, state = Step(
rng,
sampler,
target.h,
x_start;
inv_transform = target.inv_transform,
kwargs...)
if init_state == nothing
transition, state = Step(
rng,
sampler,
target.h,
x_start;
inv_transform = target.inv_transform,
kwargs...)
else
state = init_state
transition = Transition(state, target.inv_transform)
end

sample = _make_sample(transition; transform=target.transform, include_latent=include_latent)
samples = similar(sample, (length(sample), Int(floor(n/thinning))))
Expand Down Expand Up @@ -296,5 +315,9 @@ function Sample(

ProgressMeter.finish!(pbar)

return samples
if restart
return samples, state
else
return samples
end
end
14 changes: 7 additions & 7 deletions src/tuning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,16 @@ function tune_hyperparameters(
sampler.hyperparameters.sigma = sigma
end
if sampler.tune_L
ess, _ = Summarize(xs)
m_ess = mean(ess)
if m_ess > length(xs)/50
if sampler.L_tuning_method == "sigma"
eps = sampler.hyperparameters.eps
sampler.hyperparameters.L = sqrt(mean(sigma .^ 2)) * eps
end
if sampler.L_tuning_method == "ess"
ess, _ = Summarize(xs)
m_ess = mean(ess)
l = length(xs)/m_ess
eps = sampler.hyperparameters.eps
sampler.hyperparameters.L = 0.4*eps*l
else
@warn "Effective sample size is too low, using sigma to tune L"
sampler.hyperparameters.L =
sqrt(mean(sigma .^ 2)) * sampler.hyperparameters.eps
end
end
end
Expand Down

0 comments on commit fffb69e

Please sign in to comment.