Skip to content

Commit

Permalink
Optimization.jl interface and format
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Oct 8, 2024
1 parent df5df83 commit 849b944
Show file tree
Hide file tree
Showing 22 changed files with 433 additions and 290 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.0"
[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Expand All @@ -15,15 +16,15 @@ julia = "1"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"

[targets]
test = ["SparseArrays", "ForwardDiff", "Test", "SafeTestsets", "JLD2", "Random", "StaticArrays", "Statistics", "IterativeSolvers", "LinearMaps"]
93 changes: 58 additions & 35 deletions examples/lasso/FISTASolver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,48 @@ abstract type Regularizer end
# uses the Fast Iterative Soft-Thresholding Algorithm (FISTA) to
# minimize f(x) + g(x) = ½(|Ax - y|² + β|x|²) + ½α|Ψx|₁

"""
    fista(A, y, α, β, iters, tol, reg::Regularizer)

Run `fista` from an all-zero starting point.

Every argument is forwarded unchanged to the eight-argument method; the extra
`xstart` argument is `zeros(p)`, where `p` is the second entry of `size(A)`.
"""
function fista(A, y, α, β, iters, tol, reg::Regularizer)
    _, p = size(A)
    return fista(A, y, α, β, iters, tol, reg, zeros(p))
end

function fista(A, y, α, β, iters, tol, reg::Regularizer, xstart)
n, p = size(A)
maxeig = abs(powm(A' * A, maxiter=100)[1]) # TODO: handle number of eigiters

maxeig = abs(powm(A' * A, maxiter = 100)[1]) # TODO: handle number of eigiters
L = maxeig + β # Lipschitz constant of f (strongly convex term)
η = 1 / (L * lipschitz_scale(reg))

x = xstart[:]
z = x[:]
z = x[:]
xold = similar(x)
res = similar(y)
res = similar(y)
grad = similar(x)
t = 1

iters_done = iters
xupdates = Float64[]
xupdates = Float64[]
convdists = Float64[]
evals = Float64[]

Ψ = transform_op(reg)
Fold = Inf

for i = 1:iters

for i in 1:iters
xold .= x

res .= mul!(res, A, z) .- y # TODO: maybe use five-arg mul! instead. (but IterativeSolvers sticks to three-arg)
grad .= mul!(grad, A', res) .+ β .* z
grad .= mul!(grad, A', res) .+ β .* z

x .= z .- η .* grad
proximal!(x, 1/2 * η * α, reg)
proximal!(x, 1 / 2 * η * α, reg)

restart = dot(z .- x, x .- xold)
restart > 0 && (t = 1)

told = t
t = 1/2 * (1 + (1 + 4t^2))
z .= x .+ (told - 1)/t .* (x .- xold)
t = 1 / 2 * (1 + (1 + 4t^2))

z .= x .+ (told - 1) / t .* (x .- xold)

xupdate = norm(z .- xold) / norm(xold)
append!(xupdates, xupdate)
Expand All @@ -69,7 +70,7 @@ function fista(A, y, α, β, iters, tol, reg::Regularizer, xstart)
end
end

x, (;iters=iters_done, final_tol=norm(x .- xold) / norm(x), xupdates)
x, (; iters = iters_done, final_tol = norm(x .- xold) / norm(x), xupdates)
end

## regularizer implementations
Expand All @@ -78,12 +79,12 @@ support(x, thresh, reg::Regularizer) = abs.(transform_op(reg) * x) .> thresh
# Fallback for out-of-place proximal: apply the in-place `proximal!` to `copy(x)`
# and return whatever `proximal!` returns.
function proximal(x, thresh, reg::Regularizer)
    scratch = copy(x)
    return proximal!(scratch, thresh, reg)
end

## L1 regularizer
# L1 regularizer. `size` is the integer that `projection_op` passes to `L1Project`
# as its first argument.
struct L1 <: Regularizer
    size::Int
end

# Support of `x` under the L1 regularizer, using the fixed threshold 1e-3.
function support(x, reg::L1)
    return support(x, 1e-3, reg)
end
# Factor that `fista` multiplies into its Lipschitz constant when forming the step size.
lipschitz_scale(::L1) = 2.0

# proximal

Expand Down Expand Up @@ -120,9 +121,16 @@ function L1Project(p::Int, supp::AbstractVector{Bool})
end

# Applying the projection keeps the entries of `x` at the positions where
# `P.suppinv` is positive.
function Base.:(*)(P::L1Project, x::AbstractVector)
    keep = P.suppinv .> 0
    return x[keep]
end
# Transposed projection: entry `i` of the result is `y[Pt.lmap.suppinv[i]]` when that
# stored index is positive and `0.0` otherwise, for `i` in `1:Pt.lmap.p`.
function Base.:(*)(Pt::L1ProjectTranspose, y::AbstractVector)
    suppinv = Pt.lmap.suppinv
    # TODO: use zero element instead?
    return [suppinv[i] > 0 ? y[suppinv[i]] : 0.0 for i in 1:(Pt.lmap.p)]
end

# In-place product with the projection: the out-of-place product is written into
# `uflat`, and `uflat` (the destination) is returned, following the
# `LinearAlgebra.mul!` convention.
function LinearAlgebra.mul!(uflat::AbstractVecOrMat, P::L1Project, yflat::AbstractVector)
    uflat[:] = P * yflat
    return uflat
end

# In-place product with the transposed projection; same contract as the method above.
function LinearAlgebra.mul!(
        uflat::AbstractVecOrMat, Pt::L1ProjectTranspose, yflat::AbstractVector)
    uflat[:] = Pt * yflat
    return uflat
end

# Build the `L1Project` for support `supp`, sized by the regularizer's `size` field.
function projection_op(supp, reg::L1)
    return L1Project(reg.size, supp)
end

Expand All @@ -134,7 +142,7 @@ struct TVProximalWork
diff::AbstractArray{Float64}
end

struct TV <: Regularizer
struct TV <: Regularizer
size::Tuple
wrap::Bool
work::TVProximalWork
Expand Down Expand Up @@ -162,21 +170,23 @@ function proximal!(x, thresh, reg::TV)
xroll = roll(x, -1, i)
diff .= (d -> min(2 * D * thresh, abs(d) / 2) * sign(d)).(xroll .- x)
Δ .+= diff ./ (2 * D)
Δ .-= roll(diff, 1, i) ./ (2 * D)
Δ .-= roll(diff, 1, i) ./ (2 * D)
end
x .+= Δ
reshape(x, prod(reg.size))
end

# Alternative TV proximal step: calls `proxTV!` once (shape `reg.size`, 20 iterations)
# on the flattened copy `x[:]` and returns that copy.
function proximal2(x, thresh, reg::TV)
    y = x[:]
    proxTV!(y, thresh; shape = reg.size, iterations = 20)
    return y
end

# transform

# Circularly shift `x` by `shift` along dimension `dim` only, via
# `ShiftedArrays.circshift`; every other dimension gets a zero shift.
function roll(x, shift, dim)
    # The per-dimension shift tuple must have exactly `ndims(x)` entries, so `dim`
    # has to name an existing dimension.
    1 <= dim <= ndims(x) ||
        throw(ArgumentError("dim must be in 1:$(ndims(x)), got $dim"))
    # Build the tuple directly so every entry has the same type as `shift`.
    shifts = ntuple(d -> d == dim ? shift : zero(shift), ndims(x))
    return ShiftedArrays.circshift(x, shifts)
end

struct Gradient <: LinearMap{Float64}
size::Tuple
Expand All @@ -200,8 +210,13 @@ function Base.:(*)(Dt::GradientTranspose, xflat::AbstractVector)
y[:]
end

# In-place product with the gradient operator: the out-of-place product `D * yflat`
# is written into `uflat`, and `uflat` (the destination) is returned, following the
# `LinearAlgebra.mul!` convention.
function LinearAlgebra.mul!(uflat::AbstractVecOrMat, D::Gradient, yflat::AbstractVector)
    uflat[:] = D * yflat
    return uflat
end

# In-place product with the transposed gradient operator; same contract as above.
function LinearAlgebra.mul!(
        uflat::AbstractVecOrMat, Dt::GradientTranspose, yflat::AbstractVector)
    uflat[:] = Dt * yflat
    return uflat
end

function transform_op(reg::TV)
ndims = length(reg.size)
Expand Down Expand Up @@ -233,23 +248,26 @@ function TVProject(size::Tuple, supp::AbstractVector{<:Bool}, wrap::Bool)
for start_pt in CartesianIndices(size)
explored[start_pt] && continue
region = [start_pt]
explored[start_pt] = true;
s = Stack{CartesianIndex{ndims}}();
explored[start_pt] = true
s = Stack{CartesianIndex{ndims}}()
push!(s, start_pt)
while !isempty(s)
pt = pop!(s)
pttup = Tuple(pt)
for i in 1:ndims
if wrap || pt[i] < size[i]
pt_right = CartesianIndex(pttup[1:i-1]..., 1 + mod(pt[i], size[i]), pttup[i+1:ndims]...)
pt_right = CartesianIndex(pttup[1:(i - 1)]..., 1 + mod(pt[i], size[i]),
pttup[(i + 1):ndims]...)
if !supp[pt, i] && !explored[pt_right]
explored[pt_right] = true
push!(s, pt_right)
push!(region, pt_right)
end
end
if wrap || pt[i] > 1
pt_left = CartesianIndex(pttup[1:i-1]..., 1 + mod(pt[i]-2, size[i]), pttup[i+1:ndims]...)
pt_left = CartesianIndex(
pttup[1:(i - 1)]..., 1 + mod(pt[i] - 2, size[i]),
pttup[(i + 1):ndims]...)
if !supp[pt_left, i] && !explored[pt_left]
explored[pt_left] = true
push!(s, pt_left)
Expand All @@ -271,14 +289,14 @@ end

# Summarize `x` over each region of `P`.
#
# `x` is reshaped to `P.size`; for every region in `P.regions` the result holds the
# tuple `(mean, spread, count)`: the sum of the region's values divided by the
# region length, the maximum minus the minimum value, and the number of values.
function region_vals(P::TVProject, x::AbstractVector)
    grid = reshape(x, P.size)
    return map(P.regions) do region
        vals = [grid[pt] for pt in region]
        (sum(vals) / length(region), maximum(vals) - minimum(vals), length(vals))
    end
end

function Base.:(*)(Pt::TVProjectTranspose, y::AbstractVector)
function Base.:(*)(Pt::TVProjectTranspose, y::AbstractVector)
P = Pt.lmap
x = zeros(P.size)
for (val, region) in zip(y, P.regions)
Expand All @@ -288,9 +306,14 @@ function Base.:(*)(Pt::TVProjectTranspose, y::AbstractVector)
end

# Build the `TVProject` for support `supp` from the regularizer's `size` and `wrap` fields.
function projection_op(supp, reg::TV)
    return TVProject(reg.size, supp, reg.wrap)
end
# In-place product with the TV projection: the out-of-place product `P * yflat` is
# written into `uflat`, and `uflat` (the destination) is returned, following the
# `LinearAlgebra.mul!` convention.
function LinearAlgebra.mul!(uflat::AbstractVecOrMat, P::TVProject, yflat::AbstractVector)
    uflat[:] = P * yflat
    return uflat
end

# In-place product with the transposed TV projection; same contract as above.
function LinearAlgebra.mul!(
        uflat::AbstractVecOrMat, Pt::TVProjectTranspose, yflat::AbstractVector)
    uflat[:] = Pt * yflat
    return uflat
end

export fista, L1, TV

end
end
36 changes: 21 additions & 15 deletions examples/lasso/NLoptLassoData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function run_once_nlopt(G, y, α, β)
n, p = size(G)

nlopt = Opt(:LD_CCSAQ, 2p)
nlopt.lower_bounds = vcat(fill(-Inf, p), zeros(p))
nlopt.lower_bounds = vcat(fill(-Inf, p), zeros(p))
nlopt.upper_bounds = fill(Inf, 2p)
# nlopt.maxeval = 2000
nlopt.xtol_rel = 1e-8
Expand All @@ -22,16 +22,16 @@ function run_once_nlopt(G, y, α, β)
nlopt.params["verbosity"] = 0

nlopt.min_objective = SetupLasso.make_obj(G, y, α, β)
for i in 1:2p
for i in 1:(2p)
inequality_constraint!(nlopt, make_cons(p, Val(i)), 1e-10) # this is helpful for convergence. needs to be tuned well: higher and lower can both be bad.
end

u0 = zeros(p)
t0 = abs.(u0) # start the t's with some slack
u0_and_t0 = vcat(u0, t0)

(minf,minx,ret) = optimize(nlopt, u0_and_t0)
return minf,minx,ret
(minf, minx, ret) = optimize(nlopt, u0_and_t0)
return minf, minx, ret
end

## NLOpt output processing
Expand All @@ -46,8 +46,10 @@ function safe_scanf(buffer, fmt, args...)
seek(buffer, pos)
# skip ignored lines
ln = readline(buffer)
if (startswith(ln, "j=") || startswith(ln, "dx =")) || startswith(ln, "u =") || startswith(ln, "v =") || startswith(ln, "dfdx") || startswith(ln, "dfcdx") || startswith(ln, "y:")
return safe_scanf(buffer, fmt, args...)
if (startswith(ln, "j=") || startswith(ln, "dx =")) || startswith(ln, "u =") ||
startswith(ln, "v =") || startswith(ln, "dfdx") || startswith(ln, "dfcdx") ||
startswith(ln, "y:")
return safe_scanf(buffer, fmt, args...)
end
seek(buffer, pos)
return nothing
Expand All @@ -60,7 +62,7 @@ end
Requires use of custom nlopt binary: build from https://github.com/gaurav-arya/nlopt/tree/ag-debug,
move shared object file to this directory, and set with set_binary.jl.
"""
function nlopt_lasso_data(evals)
function nlopt_lasso_data(evals)
open("nlopt_out.txt", "w") do io
redirect_stdout(io) do
run_once_nlopt(evals)
Expand Down Expand Up @@ -92,41 +94,45 @@ function nlopt_lasso_data(evals)

buffer = IOBuffer(read(open("nlopt_out.txt"), String)) # easier to copy
d = DataFrame()
while true
while true
# read one inner iteration
inner_history = DataFrame()
inner_iter = 0
done = false
while true
if (out = safe_scanf(buffer, inner_iter_fmt, Int64, (Float64 for i in 1:5)...)) !== nothing
if (out = safe_scanf(buffer, inner_iter_fmt, Int64, (Float64 for i in 1:5)...)) !==
nothing
dual_iters, dual_obj, dual_opt, dual_grad, _x_proposed... = out
x_proposed = collect(_x_proposed)
else
done = true
end
if (out = safe_scanf(buffer, inner_iter2_fmt, (Float64 for i in 1:2)...)) !== nothing
if (out = safe_scanf(buffer, inner_iter2_fmt, (Float64 for i in 1:2)...)) !==
nothing
ρ = collect(out)
do_break = false
else
ρ = [NaN, NaN]
do_break = true
end
push!(inner_history, (;dual_iters, dual_obj, dual_opt, dual_grad, ρ, x_proposed))
push!(
inner_history, (; dual_iters, dual_obj, dual_opt, dual_grad, ρ, x_proposed))
do_break && break
end
done && break
safe_scanf(buffer, infeasible_point_fmt, String) # skip infeasible point log in hacky way
out = safe_scanf(buffer, outer_iter_fmt, (Float64 for i in 1:2)...)
out = safe_scanf(buffer, outer_iter_fmt, (Float64 for i in 1:2)...)
if out === nothing
break
end
ρ = collect(out)
if (out = safe_scanf(buffer, outer_iter_sigma_fmt, (Float64 for i in 1:2)...)) !== nothing
if (out = safe_scanf(buffer, outer_iter_sigma_fmt, (Float64 for i in 1:2)...)) !==
nothing
σ = collect(out)
else
σ = [NaN]
end
push!(d, (;ρ, σ, inner_history, x=inner_history.x_proposed[end]))
push!(d, (; ρ, σ, inner_history, x = inner_history.x_proposed[end]))
end
if countlines(copy(buffer)) != 0
@show countlines(copy(buffer))
Expand All @@ -137,4 +143,4 @@ end

export run_once_nlopt, nlopt_lasso_data

end
end
Loading

0 comments on commit 849b944

Please sign in to comment.