From 19b48105ecf99d47a15e055cc28fa31c7b38fd63 Mon Sep 17 00:00:00 2001 From: jaimerzp <jaimerz011235813@gmail.com> Date: Wed, 21 Feb 2024 16:19:09 +0000 Subject: [PATCH] more general random vector generation --- src/sampler.jl | 16 +++++++--------- test/base.jl | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 829e03c..134b5ca 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -60,22 +60,20 @@ function MCHMC(nadapt::Int, TEV::Real; return MCHMCSampler(nadapt, TEV, adaptive, tune_eps, tune_L, tune_sigma, hyperparameters, hamiltonian_dynamics) end -function Random_unit_vector(rng::AbstractRNG, d::Int; _normalize = true) - return Random_unit_vector(rng, d, Float64; _normalize = _normalize) -end - -function Random_unit_vector(rng::AbstractRNG, d::Int, T::Type; _normalize = true) +function Random_unit_vector(rng::AbstractRNG, x::AbstractVector{T}; _normalize = true) where {T} """Generates a random (isotropic) unit vector.""" - u = randn(rng, T, d) + u = similar(x) + randn!(rng, u) if _normalize - u = normalize(u) + u_norm = sqrt.(sum(u.^2)) + u ./= u_norm end return u end function Partially_refresh_momentum(rng::AbstractRNG, nu::T, u::AbstractVector{T}) where {T} d = length(u) - z = nu .* Random_unit_vector(rng, d, T; _normalize = false) + z = nu .* Random_unit_vector(rng, u; _normalize = false) uu = u .+ z return normalize(uu) end @@ -141,7 +139,7 @@ function Step( kwargs = Dict(kwargs) d = length(init_params) l, g = -1 .* h.∂lπ∂θ(init_params) - u = Random_unit_vector(rng, d, T) + u = Random_unit_vector(rng, init_params) eps = sampler.hyperparameters.eps Weps = T(1e-5) Feps = T(Weps * eps^(1 / 6)) diff --git a/test/base.jl b/test/base.jl index e940ec6..816aefd 100644 --- a/test/base.jl +++ b/test/base.jl @@ -50,7 +50,7 @@ end @testset "Partially_refresh_momentum" begin d = 10 rng = MersenneTwister(0) - u = MicroCanonicalHMC.Random_unit_vector(rng, d) + u = MicroCanonicalHMC.Random_unit_vector(rng, ones(d)) @test length(u) == d @test isapprox(norm(u), 1.0, rtol = 0.0000001)