Skip to content

Commit

Permalink
more general random vector generation
Browse files Browse the repository at this point in the history
  • Loading branch information
JaimeRZP committed Feb 21, 2024
1 parent 3ead468 commit 19b4810
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
16 changes: 7 additions & 9 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,20 @@ function MCHMC(nadapt::Int, TEV::Real;
return MCHMCSampler(nadapt, TEV, adaptive, tune_eps, tune_L, tune_sigma, hyperparameters, hamiltonian_dynamics)
end

function Random_unit_vector(rng::AbstractRNG, d::Int; _normalize = true)
return Random_unit_vector(rng, d, Float64; _normalize = _normalize)
end

function Random_unit_vector(rng::AbstractRNG, d::Int, T::Type; _normalize = true)
function Random_unit_vector(rng::AbstractRNG, x::AbstractVector{T}; _normalize = true) where {T}
"""Generates a random (isotropic) unit vector."""
u = randn(rng, T, d)
u = similar(x)
randn!(rng, u)
if _normalize
u = normalize(u)
u_norm = sqrt.(sum(u.^2))
u ./= u_norm
end
return u
end

function Partially_refresh_momentum(rng::AbstractRNG, nu::T, u::AbstractVector{T}) where {T}
d = length(u)
z = nu .* Random_unit_vector(rng, d, T; _normalize = false)
z = nu .* Random_unit_vector(rng, u; _normalize = false)
uu = u .+ z
return normalize(uu)
end
Expand Down Expand Up @@ -141,7 +139,7 @@ function Step(
kwargs = Dict(kwargs)
d = length(init_params)
l, g = -1 .* h.∂lπ∂θ(init_params)
u = Random_unit_vector(rng, d, T)
u = Random_unit_vector(rng, init_params)
eps = sampler.hyperparameters.eps
Weps = T(1e-5)
Feps = T(Weps * eps^(1 / 6))
Expand Down
2 changes: 1 addition & 1 deletion test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ end
@testset "Partially_refresh_momentum" begin
d = 10
rng = MersenneTwister(0)
u = MicroCanonicalHMC.Random_unit_vector(rng, d)
u = MicroCanonicalHMC.Random_unit_vector(rng, ones(d))
@test length(u) == d
@test isapprox(norm(u), 1.0, rtol = 0.0000001)

Expand Down

0 comments on commit 19b4810

Please sign in to comment.