From 19b48105ecf99d47a15e055cc28fa31c7b38fd63 Mon Sep 17 00:00:00 2001
From: jaimerzp <jaimerz011235813@gmail.com>
Date: Wed, 21 Feb 2024 16:19:09 +0000
Subject: [PATCH] more general random vector generation

---
 src/sampler.jl | 16 +++++++---------
 test/base.jl   |  2 +-
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/src/sampler.jl b/src/sampler.jl
index 829e03c..134b5ca 100644
--- a/src/sampler.jl
+++ b/src/sampler.jl
@@ -60,22 +60,20 @@ function MCHMC(nadapt::Int, TEV::Real;
     return MCHMCSampler(nadapt, TEV, adaptive, tune_eps, tune_L, tune_sigma, hyperparameters, hamiltonian_dynamics)
 end
 
-function Random_unit_vector(rng::AbstractRNG, d::Int; _normalize = true)
-    return Random_unit_vector(rng, d, Float64; _normalize = _normalize)
-end
-
-function Random_unit_vector(rng::AbstractRNG, d::Int, T::Type; _normalize = true)
+function Random_unit_vector(rng::AbstractRNG, x::AbstractVector{T}; _normalize = true) where {T}
     """Generates a random (isotropic) unit vector."""
-    u = randn(rng, T, d)
+    u = similar(x)
+    randn!(rng, u)
     if _normalize
-        u = normalize(u)
+        u_norm = sqrt.(sum(u.^2))
+        u ./= u_norm
     end
     return u
 end
 
 function Partially_refresh_momentum(rng::AbstractRNG, nu::T, u::AbstractVector{T}) where {T}
     d = length(u)
-    z = nu .* Random_unit_vector(rng, d, T; _normalize = false)
+    z = nu .* Random_unit_vector(rng, u; _normalize = false)
     uu = u .+ z
     return normalize(uu)
 end
@@ -141,7 +139,7 @@ function Step(
     kwargs = Dict(kwargs)
     d = length(init_params)
     l, g = -1 .* h.∂lπ∂θ(init_params)
-    u = Random_unit_vector(rng, d, T)
+    u = Random_unit_vector(rng, init_params)
     eps = sampler.hyperparameters.eps
     Weps = T(1e-5)
     Feps = T(Weps * eps^(1 / 6))
diff --git a/test/base.jl b/test/base.jl
index e940ec6..816aefd 100644
--- a/test/base.jl
+++ b/test/base.jl
@@ -50,7 +50,7 @@ end
 @testset "Partially_refresh_momentum" begin
     d = 10
     rng = MersenneTwister(0)
-    u = MicroCanonicalHMC.Random_unit_vector(rng, d)
+    u = MicroCanonicalHMC.Random_unit_vector(rng, ones(d))
     @test length(u) == d
     @test isapprox(norm(u), 1.0, rtol = 0.0000001)