Skip to content

Commit

Permalink
l-curve method fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
aris committed Oct 21, 2024
1 parent f67dd00 commit e9e6ad0
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 87 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NMRInversions"
uuid = "55c20db2-0166-4687-95c3-62a9c7afb29b"
authors = ["Aristarchos Mavridis <[email protected]>"]
version = "0.9.0"
version = "0.9.1"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/types_structs.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ optim_nnls
## Finding optimal alpha
These are methods for finding the optimal regularization parameter.
They can be used as input to the `invert` function as the 'alpha' argument
(e.g., `invert(data, alpha=gcv)` ).
(e.g., `invert(data, alpha=gcv)` or `invert(data, alpha=lcurve(0.001,1,64))` ).
If you'd like to use a particular value of alpha,
you can just use that number instead (`invert(data, alpha=1`).
you can just use that number instead (`invert(data, alpha=1)`).
```@docs
gcv
lcurve
Expand Down
11 changes: 9 additions & 2 deletions ext/gui_ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ function Makie.plot(res::NMRInversions.inv_out_2D)
clearb = Button(gui[2, 10:15]; label="Clear current selection")
resetb = Button(gui[1, 15:19]; label="Reset everything")
filterb = Button(gui[2, 15:19]; label="Filter-out unselected")
saveb = Button(gui[3, 15:19]; label="Save plot (WIP)")
saveb = Button(gui[3, 10:19]; label="Save and exit")

# Title textbox
tb = Textbox(gui[1, 2:7], placeholder="Insert a title for the plot, then press enter.", width=300, reset_on_defocus=true)
Expand Down Expand Up @@ -425,7 +425,14 @@ function Makie.plot(res::NMRInversions.inv_out_2D)
ttl = tb.stored_string[]
end

f = plot(res, title=ttl)
f = plot(res, ttl)
savedir = NMRInversions.save_file(ttl, filterlist = "png")

if savedir == ""
display("Please enter a name for your file on the file dialog.")
else
save(savedir, f)
end

end

Expand Down
50 changes: 19 additions & 31 deletions src/finding_alpha.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,30 @@ function solve_gcv(svds::svd_kernel_struct, solver::Union{regularization_solver,
return f, r, α
end



"""
Compute the curvature of the L-curve at a given point.
(Hansen 2010 page 92-93)
- `f` : solution vector
- `r` : residuals
- `α` : smoothing term
- `A` : Augmented kernel matrix (`K` and `αI` stacked vertically)
- `b` : Augmented residuals (`r` and `0` stacked vertically)
"""
function l_curvature(f, r, α, A)
function l_curvature(f, r, α, A, b)

ξ = f'f
ρ = r'r
λ = α

z = solve_nnls(A, r)
z = NMRInversions.solve_ls(A, b)

∂ξ∂λ = 4 / λ * f'z
∂ξ∂λ = (4 / λ) * f'z

= (2ξ * ρ / ∂ξ∂λ) ** ∂ξ∂λ * ρ + 2 * ξ * λ * ρ + λ^4 * ξ * ∂ξ∂λ) /* ξ^2 + ρ^2)^(3 / 2)
= 2 * * ρ / ∂ξ∂λ) ** ∂ξ∂λ * ρ + 2 * ξ * λ * ρ + λ^4 * ξ * ∂ξ∂λ) / ((α * ξ^2 + ρ^2)^(3 / 2))

return

Expand All @@ -108,45 +112,29 @@ end
"""
"""
function solve_l_curve(K, g, lower, upper, n)
function solve_l_curve(K, g, solver, lower, upper, n)

alphas = exp10.(range(log10(lower), log10(upper), n))
curvatures = zeros(length(alphas))

ξarray = zeros(length(alphas))
ρarray = zeros(length(alphas))

for (i, α) in enumerate(alphas)
A = sparse([K; (α) .* NMRInversions.Γ(size(K, 2), order)])

f = vec(nonneg_lsq(A, [y; zeros(size(A, 1) - size(y, 1))], alg=:nnls))
r = K * f - y
display("Testing α = $(round(α,sigdigits=3))")

ξ = f'f
ρ = r'r
f, r = NMRInversions.solve_regularization(K, g, α, solver)

λ = α
A = sparse([K; (α) * LinearAlgebra.I ])
b = sparse([r; zeros(size(A, 1) - size(r, 1))])

ξarray[i] = ξ
ρarray[i] = ρ
c = l_curvature(f, r, α, A, b)

z = vec(nonneg_lsq(A, [r; zeros(size(A, 1) - size(r, 1))], alg=:nnls))
curvatures[i] = c

∂ξ∂λ = (4 / λ) * f'z

= (2ξ * ρ / ∂ξ∂λ) ** ∂ξ∂λ * ρ + 2 * ξ * λ * ρ + λ^4 * ξ * ∂ξ∂λ) /* ξ^2 + ρ^2)^(3 / 2)
end

curvatures[i] =
α = alphas[argmin(curvatures)]
display("The optimal α is $(round(α,sigdigits=3))")

end
# plot(ρarray, ξarray, xscale=:log10, yscale=:log10)

non_inf_indx = findall(!isinf, curvatures)
argmax(curvatures[non_inf_indx])
α = alphas[non_inf_indx][argmax(curvatures[non_inf_indx])]
A = sparse([K; (α) .* NMRInversions.Γ(size(K, 2), order)])
f = vec(nonneg_lsq(A, [y; zeros(size(A, 1) - size(y, 1))], alg=:nnls))
r = K * f - y
f, r = NMRInversions.solve_regularization(K, g, α, solver)

return f, r, α

Expand Down
17 changes: 12 additions & 5 deletions src/inversion_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,25 @@ function invert(seq::Type{<:pulse_sequence1D}, x::AbstractArray, y::Vector;
X .= X .* 1e9
end

α = 1 #placeholder, will be replaced below
ker_struct = create_kernel(seq, x, X, y)
α = 1.0 #placeholder, will be replaced below

if isa(alpha, Real)

α = alpha
ker_struct = create_kernel(seq, x, X, y)

f, r = solve_regularization(ker_struct.K, ker_struct.g, α, solver)

elseif alpha == gcv

ker_struct = create_kernel(seq, x, X, y)
f, r, α = solve_gcv(ker_struct, solver)

elseif isa(alpha, lcurve)
f, r, α = solve_l_curve(ker_struct.K, ker_struct.g, solver,
alpha.lowest_value, alpha.highest_value, alpha.number_of_steps)

else
error("alpha must be a real number or gcv")
error("alpha must be a real number or a smoothing_optimizer type.")

end

Expand Down Expand Up @@ -149,8 +152,12 @@ function invert(
elseif alpha == gcv
f, r, α = solve_gcv(ker_struct, solver)

elseif isa(alpha, lcurve)
f, r, α = solve_l_curve(ker_struct.K, ker_struct.g, solver,
alpha.lowest_value, alpha.highest_value, alpha.number_of_steps)

else
error("alpha must be a real number or gcv")
error("alpha must be a real number or a smoothing_optimizer type.")

end

Expand Down
25 changes: 22 additions & 3 deletions src/optim_regularizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,35 @@ function solve_regularization(K::AbstractMatrix, g::AbstractVector, α::Real, so
return f, r
end


"""
Solve a least squares problem, with nonnegativity constraints.
"""
function solve_nnls(A::AbstractMatrix, b::AbstractVector)

optf = Optimization.OptimizationFunction(obj_f, Optimization.AutoForwardDiff())
prob = Optimization.OptimizationProblem(optf, ones(size(A, 2)), (A, b), lb=zeros(size(A, 2)), ub=Inf * ones(size(A, 2)))
x = OptimizationOptimJL.solve(prob, OptimizationOptimJL.LBFGS(), maxiters=5000, maxtime=100)

return x
end
function obj_f(x, p)
return sum((p[1] * x - p[2]).^2)
end


function solve_nnls(A::AbstractMatrix, b::AbstractVector)
"""
Solve a least squares problem, without nonnegativity constraints.
"""
function solve_ls(A::AbstractMatrix, b::AbstractVector)

optf = Optimization.OptimizationFunction(obj_f, Optimization.AutoForwardDiff())
prob = Optimization.OptimizationProblem(optf, ones(size(A, 2)), (A, b), lb=zeros(size(A, 2)), ub=Inf * ones(size(A, 2)))
optf = Optimization.OptimizationFunction(obj_ls, Optimization.AutoForwardDiff())
prob = Optimization.OptimizationProblem(optf, ones(size(A, 2)), (A, b),
lb= -Inf .* ones(size(A, 2)), ub=Inf .* ones(size(A, 2)))
x = OptimizationOptimJL.solve(prob, OptimizationOptimJL.LBFGS(), maxiters=5000, maxtime=100)

return x
end
function obj_ls(x, p)
return sum((p[1] * x - p[2]).^2)
end
27 changes: 22 additions & 5 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ Simple non-negative least squares method for tikhonov (L2) regularization,
implemented using OptimizationOptimJl.
All around effective, but can be slow for large problems, such as 2D inversions.
It can be used as a "solver" for invert function.
Order determines the tikhonov matrix. If 0 is chosen, the identity matrix is used.
Order is an integer that determines the tikhonov matrix
(for more info look Hansen's 2010 book on inverse problems).
Order `n` means that the penalty term will be the n'th derivative
of the results.
"""
struct optim_nnls <: regularization_solver
order::Int
Expand Down Expand Up @@ -118,12 +121,26 @@ struct gcv <: smoothing_optimizer end


"""
lcurve
L curve method for finding the optimal regularization parameter α.
lcurve(a,b,n)
L-curve method for finding the optimal regularization parameter α.
It will test a set of logarithmically-spaced values,
starting from `a`, ending in `b`, and consisting of `n` values.
The optimal value will be chosen based on the maximum curvature of the L curve,
as described in Hansen 2010, "Discrete Inverse Problems".
This is usually less reliable than the `gcv` method,
but it's nice to have multiple options.
Note that the first time you use this in every session,
it will take about a minute to compile the code.
This might be optimized in the future, if there's demand for it.
STILL UNDER DEVELOPMENT!
"""
struct lcurve <: smoothing_optimizer end
struct lcurve <: smoothing_optimizer
lowest_value::Real
highest_value::Real
number_of_steps::Int
end

export smoothing_optimizer, gcv, lcurve

Expand Down
66 changes: 28 additions & 38 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using NMRInversions
#=using Plots=#
using SparseArrays
using LinearAlgebra
using Test
using Optimization, OptimizationOptimJL
Expand All @@ -22,52 +24,43 @@ function test1D(seq::Type{<:pulse_sequence1D})
end




function test_lcurve()
# @time begin

# x = exp10.(range(log10(1e-4), log10(5), 32)) # acquisition range
x = collect(range(0.01, 2, 32))
x = collect(range(0.01, 1, 32))
X = exp10.(range(-5, 1, 128)) # T range
# K = create_kernel(IR, x, X)
K = create_kernel(CPMG, x, X)
f_custom = [0.5exp.(-(x)^2 / 3) + exp.(-(x - 1.3)^2 / 0.5) for x in range(-5, 5, length(X))]

g = K * f_custom
noise_level = 0.001 * maximum(g)
y = g + noise_level .* randn(length(x))

alphas = exp10.(range(log10(1e-5), log10(1), 128))
#=data = input1D(CPMG, x, y)=#
#=plot(invert(data,alpha=lcurve(1e-5,1,64)))=#

alphas = exp10.(range(log10(1e-5), log10(1e-1), 128))
curvatures = zeros(length(alphas))
xis = zeros(length(alphas))
rhos = zeros(length(alphas))
order = 0

U, s, V = svd(K)
s_keep_ind = findall(x -> x > noise_level, s)
U = U[:, s_keep_ind]
s = s[s_keep_ind]
V = V[:, s_keep_ind]
K = U * Diagonal(s) * V'

for (i, α) in enumerate(alphas)
A = sparse([K; (α) .* NMRInversions.Γ(size(K, 2), order)])
println(α)
println("α = ", α)

# f = vec(nonneg_lsq(A, [y; zeros(size(A, 1) - size(y, 1))], alg=:nnls))
# r = K * f - y
f, r = NMRInversions.solve_regularization(K, y, α, brd, 0)
f, r = NMRInversions.solve_regularization(K, y, α, brd)

ξ = f'f
ρ = r'r
λ = α

# z = vec(nonneg_lsq(A, [r; zeros(size(A, 1) - size(r, 1))], alg=:nnls))
f, _ = NMRInversions.solve_regularization(K, r, α, brd, 0)
A = sparse([K; (α) * LinearAlgebra.I ])
b = sparse([r; zeros(size(A, 1) - size(r, 1))])

fᵢ = s .^ 2 ./ (s .^ 2 .+ α)
βᵢ = U' * y
∂ξ∂λ = -(4 / λ) * sum((1 .- fᵢ) .* fᵢ .^ 2 .* (βᵢ .^ 2 ./ s .^ 2))
# ∂ξ∂λ = (4 / λ) * f'z
z = NMRInversions.solve_ls(A, b)

∂ξ∂λ = (4 / λ) * f'z

= 2 ** ρ / ∂ξ∂λ) ** ∂ξ∂λ * ρ + 2 * ξ * λ * ρ + λ^4 * ξ * ∂ξ∂λ) / ((α * ξ^2 + ρ^2)^(3 / 2))

Expand All @@ -77,26 +70,23 @@ function test_lcurve()

end

non_inf_indx = findall(!isinf, curvatures)

α = alphas[non_inf_indx][argmax(curvatures[non_inf_indx])]
α = alphas[argmin(curvatures)]

A = sparse([K; (α) .* NMRInversions.Γ(size(K, 2), order)])
f, r = NMRInversions.solve_regularization(K, y, α, brd, 0)
# f = vec(nonneg_lsq(A, [y; zeros(size(A, 1) - size(y, 1))], alg=:nnls))
f, r = NMRInversions.solve_regularization(K, y, α, brd)

p1 = plot(alphas, curvatures, xscale=:log10)
p1 = vline!(p1, [α], label="α = ")
p2 = plot(X, [f_custom, f], label=["original" "solution"], xscale=:log10)
p3 = scatter(rhos, xis, xscale=:log10, yscale=:log10)
p3 = scatter!([rhos[argmax(curvatures[non_inf_indx])]], [xis[argmax(curvatures[non_inf_indx])]], label="α = ")
p4 = scatter(x, y, label="data")
p4 = plot!(x, K * f, label="solution")
plot(p1, p2, p3, p4)
begin
p1 = plot(alphas, curvatures, xscale=:log10, xlabel="α", ylabel="curvature",label = "curvature vs. α");
p1 = vline!(p1, [α], label="α = ");
p2 = plot(X, [f_custom, f], label=["original" "solution"], xscale=:log10);
p3 = plot(rhos, xis, xscale=:log10, yscale=:log10,label = "lcurve");
p3 = scatter!([rhos[argmin(curvatures)]], [xis[argmin(curvatures)]], label="α = ");
p4 = scatter(x, y, label="data");
p4 = plot!(x, K * f, label="solution");
plot(p1, p2, p3, p4)
end

end


function testT1T2()

x_direct = exp10.(range(log10(1e-4), log10(5), 1024)) # acquisition range
Expand Down

2 comments on commit e9e6ad0

@aris-mav
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/117770

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.1 -m "<description of version>" e9e6ad0301286f062763191ac9ae81ac7235cd6b
git push origin v0.9.1

Please sign in to comment.