Skip to content

Commit

Permalink
swapped to optim, refactored find_alpha, tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
aris-mav committed Jan 3, 2025
1 parent 352b798 commit 984237e
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 168 deletions.
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
name = "NMRInversions"
uuid = "55c20db2-0166-4687-95c3-62a9c7afb29b"
authors = ["Aristarchos Mavridis <[email protected]>"]
version = "0.9.2"
version = "0.9.3"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NativeFileDialog = "e1fe445b-aa65-4df4-81c1-2041507f0fd4"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolygonOps = "647866c9-e3ac-4575-94e7-e3d426903924"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -29,8 +28,7 @@ DelimitedFiles = "1"
GLMakie = "0.10"
JuMP = "1"
NativeFileDialog = "0.2"
Optimization = "3, 4"
OptimizationOptimJL = "0.3, 0.4"
Optim = "1.10.0"
PolygonOps = "0.1"
QuadraticModels = "0.9"
RipQP = "0.6"
Expand Down
4 changes: 2 additions & 2 deletions src/NMRInversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ using LinearAlgebra
using SparseArrays
using NativeFileDialog
using PolygonOps
import Optimization, OptimizationOptimJL
using Optim

"""
to do list:
- add gcv for reci method
- differentiate between Mitchell GCV and optimization GCV
- differentiate between Mitchell GCV and optim GCV
- introduce faf, flip angle fraction, to the kernel functions. 1 would be a perfect pulse, 0 would be no pulse.
- add precompilation
Expand Down
14 changes: 7 additions & 7 deletions src/exp_fits.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

export mexp
"""
mexp(seq, u, x)
Expand Down Expand Up @@ -61,7 +60,7 @@ Arguments:
- `y` : acquisition y parameter (magnetization).
Optional arguments:
- `solver` : OptimizationOptimJL solver, defeault choice is BFGS().
- `solver` : Optim solver, defeault choice is IPNewton().
- `normalize` : Normalize the data before fitting? (default is true).
- `L` : An integer specifying which norm of the residuals you want to minimize (default is 2).
Expand Down Expand Up @@ -91,7 +90,7 @@ function expfit(
seq::Type{<:NMRInversions.pulse_sequence1D},
x::Vector,
y::Vector;
solver=OptimizationOptimJL.BFGS(),
solver= IPNewton(),
normalize::Bool=true,
L::Int = 2
)
Expand All @@ -116,10 +115,11 @@ function expfit(

end

# Solve the optimization
optf = Optimization.OptimizationFunction(mexp_loss, Optimization.AutoForwardDiff())
prob = Optimization.OptimizationProblem(optf, u0, (x, y, seq, L), lb=zeros(length(u0)), ub=Inf .* ones(length(u0)))
u = OptimizationOptimJL.solve(prob, solver, maxiters=5000, maxtime=100)
u = optimize(
u -> mexp_loss(u, (x, y, seq, L)),
zeros(length(u0)), Inf .* ones(length(u0)), u0,
solver
).minimizer

# Determine what's the x-axis of the seq (time or bfactor)
seq == NMRInversions.PFG ? x_ax = "b" : x_ax = "t"
Expand Down
149 changes: 119 additions & 30 deletions src/finding_alpha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,65 @@ function gcv_score(α, r, s, x; next_alpha=true)
end
end

function gcv_cost::Real,
svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}})

#=display("Testing α = $(round(α,sigdigits=3))")=#
f, r = solve_regularization(svds.K, svds.g, α, solver)
return gcv_score(α, r, svds.S, (svds.V' * f), next_alpha = false)
end

"""
Compute the curvature of the L-curve at a given point.
(Hansen 2010 page 92-93)
- `f` : solution vector
- `r` : residuals
- `α` : smoothing term
- `A` : Augmented kernel matrix (`K` and `αI` stacked vertically)
- `b` : Augmented residuals (`r` and `0` stacked vertically)
"""
function l_curvature(f, r, α, A, b)

ξ = f'f
ρ = r'r
λ = α

z = NMRInversions.solve_ls(A, b)

∂ξ∂λ = (4 / λ) * f'z

= 2 ** ρ / ∂ξ∂λ) ** ∂ξ∂λ * ρ + 2 * ξ * λ * ρ + λ^4 * ξ * ∂ξ∂λ) / ((α * ξ^2 + ρ^2)^(3 / 2))

return

end

function l_cost(K, g, α, solver)

display("Testing α = $(round(α,sigdigits=3))")

f, r = NMRInversions.solve_regularization(K, g, α, solver)

A = sparse([K; (α) * LinearAlgebra.I ])
b = sparse([r; zeros(size(A, 1) - size(r, 1))])

return l_curvature(f, r, α, A, b)

end




"""
Solve repeatedly until the GCV score stops decreasing.
Solve repeatedly until the GCV score stops decreasing, following Mitchell 2012 paper.
Select the solution with minimum gcv score and return it, along with the residuals.
"""
function solve_gcv(svds::svd_kernel_struct, solver::Union{regularization_solver, Type{<:regularization_solver}})
function find_alpha(svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}},
mode::gcv_mitchell)

= svds.S
= length(s̃)
Expand Down Expand Up @@ -82,29 +135,68 @@ end


"""
Compute the curvature of the L-curve at a given point.
(Hansen 2010 page 92-93)
Find alpha via univariate optimization.
"""
function find_alpha(svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}},
mode::find_alpha_univariate
)

- `f` : solution vector
- `r` : residuals
- `α` : smoothing term
- `A` : Augmented kernel matrix (`K` and `αI` stacked vertically)
- `b` : Augmented residuals (`r` and `0` stacked vertically)
local f
if mode.search_method == :gcv
f = x -> gcv_cost(x, svds, solver)

elseif mode.search_method == :lcurve
f = x -> l_cost(svds.K, svds.g, x, solver)

end

sol = optimize(
f,
mode.lower, mode.upper,
mode.algorithm,
abs_tol = mode.abs_tol
)

α = sol.minimizer
display("Converged at α =$(round(α,sigdigits=3)), after $(sol.f_calls) calls.")

f, r = NMRInversions.solve_regularization(svds.K, svds.g, α, solver)

return f, r, α

end

"""
function l_curvature(f, r, α, A, b)
Find alpha using Fminbox optimization.
"""
function find_alpha(svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}},
mode::find_alpha_box
)

ξ = f'f
ρ = r'r
λ = α
local f
if mode.search_method == :gcv
f = x -> gcv_cost(first(x), svds, solver)

z = NMRInversions.solve_ls(A, b)
elseif mode.search_method == :lcurve

∂ξ∂λ = (4 / λ) * f'z
f = x -> l_cost(svds.K, svds.g, first(x), solver)
end

= 2 ** ρ / ∂ξ∂λ) ** ∂ξ∂λ * ρ + 2 * ξ * λ * ρ + λ^4 * ξ * ∂ξ∂λ) / ((α * ξ^2 + ρ^2)^(3 / 2))
sol = optimize(
f,
[0], [Inf], [mode.start],
Fminbox(mode.algorithm),
mode.opts
)

return
α = sol.minimizer[1]

display("Converged at α =$(round(α,sigdigits=3)), after $(sol.f_calls) calls.")
f, r = NMRInversions.solve_regularization(svds.K, svds.g, α, solver)

return f, r, α

end

Expand All @@ -113,30 +205,27 @@ end
Test `n` alpha values between `lower` and `upper` and select the one
which is at the heel of the L curve, accoding to Hansen 2010.
"""
function solve_l_curve(K, g, solver, lower, upper, n)
function find_alpha(svds::svd_kernel_struct,
solver::Union{regularization_solver, Type{<:regularization_solver}},
mode::lcurve_range
)

alphas = exp10.(range(log10(mode.lowest_value),
log10(mode.highest_value),
mode.number_of_steps))

alphas = exp10.(range(log10(lower), log10(upper), n))
curvatures = zeros(length(alphas))

for (i, α) in enumerate(alphas)
display("Testing α = $(round(α,sigdigits=3))")

f, r = NMRInversions.solve_regularization(K, g, α, solver)

A = sparse([K; (α) * LinearAlgebra.I ])
b = sparse([r; zeros(size(A, 1) - size(r, 1))])

c = l_curvature(f, r, α, A, b)

curvatures[i] = c
curvatures[i] = l_cost(svds.K, svds.g, α, solver)

end

α = alphas[argmin(curvatures)]
display("The optimal α is $(round(α,sigdigits=3))")

f, r = NMRInversions.solve_regularization(K, g, α, solver)

f, r = NMRInversions.solve_regularization(svds.K, svds.g, α, solver)
return f, r, α

end
37 changes: 11 additions & 26 deletions src/inversion_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ Alternatiively, a vector of values can be used directly, if more freedom is need
"""
function invert(seq::Type{<:pulse_sequence1D}, x::AbstractArray, y::Vector;
lims::Union{Tuple{Real, Real, Int}, AbstractVector, Type{<:pulse_sequence1D}}=seq,
alpha::Union{Real, smoothing_optimizer, Type{<:smoothing_optimizer}}=gcv,
solver::Union{regularization_solver, Type{<:regularization_solver}}=brd,
normalize::Bool = true)
lims::Union{Tuple{Real, Real, Int}, AbstractVector, Type{<:pulse_sequence1D}}=seq,
alpha::Union{Real, alpha_optimizer}=gcv(),
solver::Union{regularization_solver, Type{<:regularization_solver}}=brd(),
normalize::Bool = true
)

if normalize
y = y ./ y[argmax(real(y))]
Expand All @@ -51,33 +52,24 @@ function invert(seq::Type{<:pulse_sequence1D}, x::AbstractArray, y::Vector;
end

ker_struct = create_kernel(seq, x, X, y)
α = 1.0 #placeholder, will be replaced below
α = 0.0 #placeholder, will be replaced below

if isa(alpha, Real)

α = alpha

f, r = solve_regularization(ker_struct.K, ker_struct.g, α, solver)

elseif alpha == gcv

f, r, α = solve_gcv(ker_struct, solver)
else

elseif isa(alpha, lcurve)
f, r, α = solve_l_curve(ker_struct.K, ker_struct.g, solver,
alpha.lowest_value, alpha.highest_value, alpha.number_of_steps)
f, r, α = find_alpha(ker_struct, solver, alpha)

else
error("alpha must be a real number or a smoothing_optimizer type.")

end

x_fit = exp10.(range(log10(1e-8), log10(1.1 * x[end]), 512))
y_fit = create_kernel(seq, x_fit, X) * f

isreal(y) ? SNR = NaN : SNR = calc_snr(y)


if seq == PFG
X .= X ./ 1e9
end
Expand Down Expand Up @@ -131,8 +123,8 @@ function invert(
seq::Type{<:pulse_sequence2D}, x_direct::AbstractVector, x_indirect::AbstractVector, Data::AbstractMatrix;
lims1::Union{Tuple{Real, Real, Int}, AbstractVector}=(-5, 1, 100),
lims2::Union{Tuple{Real, Real, Int}, AbstractVector}=(-5, 1, 100),
alpha::Union{Real, smoothing_optimizer, Type{<:smoothing_optimizer}}=gcv,
solver::Union{regularization_solver, Type{<:regularization_solver}}=brd,
alpha::Union{Real, alpha_optimizer} = gcv(),
solver::Union{regularization_solver, Type{<:regularization_solver}}=brd(),
normalize::Bool=true)

if normalize
Expand Down Expand Up @@ -160,15 +152,8 @@ function invert(
α = alpha
f, r = solve_regularization(ker_struct.K, ker_struct.g, α, solver)

elseif alpha == gcv
f, r, α = solve_gcv(ker_struct, solver)

elseif isa(alpha, lcurve)
f, r, α = solve_l_curve(ker_struct.K, ker_struct.g, solver,
alpha.lowest_value, alpha.highest_value, alpha.number_of_steps)

else
error("alpha must be a real number or a smoothing_optimizer type.")
f, r, α = find_alpha(ker_struct, solver, alpha)

end

Expand Down
Loading

0 comments on commit 984237e

Please sign in to comment.