From d6f0644c4856f3fe0ac293f70b4863d1b7a609d0 Mon Sep 17 00:00:00 2001 From: jaimerzp Date: Wed, 21 Feb 2024 16:22:13 +0000 Subject: [PATCH] no mormalise --- src/sampler.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index 134b5ca..00b1f32 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -65,8 +65,7 @@ function Random_unit_vector(rng::AbstractRNG, x::AbstractVector{T}; _normalize = u = similar(x) randn!(rng, u) if _normalize - u_norm = sqrt.(sum(u.^2)) - u ./= u_norm + u ./= norm(u) end return u end @@ -75,7 +74,7 @@ function Partially_refresh_momentum(rng::AbstractRNG, nu::T, u::AbstractVector{T d = length(u) z = nu .* Random_unit_vector(rng, u; _normalize = false) uu = u .+ z - return normalize(uu) + return uu ./ norm(uu) end function Update_momentum(d::Int, eff_eps::T, g::AbstractVector{T}, u::AbstractVector{T}) where {T} @@ -89,7 +88,7 @@ function Update_momentum(d::Int, eff_eps::T, g::AbstractVector{T}, u::AbstractVe zeta = exp(-delta) uu = e .* ((1 - zeta) * (1 + zeta + ue * (1 - zeta))) + (2 * zeta) .* u delta_r = delta - log(2) + log(1 + ue + (1 - ue) * zeta^2) - return normalize(uu), delta_r + return uu ./ norm(uu), delta_r end struct MCHMCState{T}