Skip to content

Commit

Permalink
Merge pull request #847 from SciML/libsfixes
Browse files (browse the repository at this point in the history)
MOI vector lambda and iteration fixes in Optimisers
  • Loading branch information
Vaibhavdixit02 authored Oct 27, 2024
2 parents c526d71 + 3d02908 commit 29309e8
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 31 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Optimization"
uuid = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
version = "4.0.3"
version = "4.0.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion lib/OptimizationMOI/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationMOI"
uuid = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.5.0"
version = "0.5.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
2 changes: 1 addition & 1 deletion lib/OptimizationMOI/src/nlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ function MOI.eval_hessian_lagrangian(evaluator::MOIOptimizationNLPEvaluator{T},
σ,
μ) where {T}
if evaluator.f.lag_h !== nothing
evaluator.f.lag_h(h, x, σ, μ)
evaluator.f.lag_h(h, x, σ, Vector(μ))
return
end
if evaluator.f.hess === nothing
Expand Down
2 changes: 1 addition & 1 deletion lib/OptimizationOptimisers/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationOptimisers"
uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.3.3"
version = "0.3.4"

[deps]
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Expand Down
65 changes: 41 additions & 24 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,27 @@ function SciMLBase.__solve(cache::OptimizationCache{
P,
C
}
maxiters = if cache.solver_args.epochs === nothing
if OptimizationBase.isa_dataiterator(cache.p)
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
end

epochs = if cache.solver_args.epochs === nothing
if cache.solver_args.maxiters === nothing
throw(ArgumentError("The number of epochs must be specified with either the epochs or maxiters kwarg."))
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
else
cache.solver_args.maxiters
cache.solver_args.maxiters / length(data)
end
else
cache.solver_args.epochs
end

maxiters = Optimization._check_and_convert_maxiters(maxiters)
if maxiters === nothing
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
end

if OptimizationBase.isa_dataiterator(cache.p)
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
epochs = Optimization._check_and_convert_maxiters(epochs)
if epochs === nothing
throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data)."))
end

opt = cache.opt
Expand All @@ -75,32 +75,50 @@ function SciMLBase.__solve(cache::OptimizationCache{
min_θ = cache.u0

state = Optimisers.setup(opt, θ)

iterations = 0
fevals = 0
gevals = 0
t0 = time()
breakall = false
Optimization.@withprogress cache.progress name="Training" begin
for epoch in 1:maxiters
for epoch in 1:epochs
if breakall
break
end
for (i, d) in enumerate(data)
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
iterations += 1
fevals += 1
gevals += 1
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
iterations += 1
fevals += 2
gevals += 1
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
iterations += 1
fevals += 1
gevals += 1
else
cache.f.grad(G, θ)
x = cache.f(θ)
iterations += 1
fevals += 2
gevals += 1
end
opt_state = Optimization.OptimizationState(
iter = i + (epoch - 1) * length(data),
u = θ,
objective = x[1],
grad = G,
original = state)
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
breakall = cache.callback(opt_state, x...)
if !(breakall isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
elseif cb_call
elseif breakall
break
end
msg = @sprintf("loss: %.3g", first(x)[1])
Expand All @@ -112,7 +130,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
min_err = x
min_θ = copy(θ)
end
if i == maxiters #Last iter, revert to best.
if i == length(data)*epochs #Last iter, revert to best.
opt = min_opt
x = min_err
θ = min_θ
Expand All @@ -122,7 +140,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
objective = x[1],
grad = G,
original = state)
cache.callback(opt_state, x...)
breakall = cache.callback(opt_state, x...)
break
end
end
Expand All @@ -132,10 +150,9 @@ function SciMLBase.__solve(cache::OptimizationCache{
end

t1 = time()
stats = Optimization.OptimizationStats(; iterations = maxiters,
time = t1 - t0, fevals = maxiters, gevals = maxiters)
stats = Optimization.OptimizationStats(; iterations,
time = t1 - t0, fevals, gevals)
SciMLBase.build_solution(cache, cache.opt, θ, first(x)[1], stats = stats)
# here should be build_solution to create the output message
end

end
14 changes: 12 additions & 2 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ using Zygote

sol = solve(prob, Optimisers.Adam(), maxiters = 1000)
@test 10 * sol.objective < l1
@test sol.stats.iterations == 1000
@test sol.stats.fevals == 1000
@test sol.stats.gevals == 1000

@testset "cache" begin
objective(x, p) = (p[1] - x[1])^2
Expand Down Expand Up @@ -73,7 +76,7 @@ end
using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random,
ComponentArrays

x = rand(10000)
x = rand(Float32, 10000)
y = sin.(x)
data = MLUtils.DataLoader((x, y), batchsize = 100)

Expand All @@ -96,7 +99,14 @@ end
optf = OptimizationFunction(loss, AutoZygote())
prob = OptimizationProblem(optf, ps_ca, data)

res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000)
res = Optimization.solve(prob, Optimisers.Adam(), epochs = 50)

@test res.objective < 1e-4
@test res.stats.iterations == 50*length(data)
@test res.stats.fevals == 50*length(data)
@test res.stats.gevals == 50*length(data)

res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100)

@test res.objective < 1e-4

Expand Down
3 changes: 2 additions & 1 deletion src/sophia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
cache.f.grad(gₜ, θ)
x = cache.f(θ)
end
opt_state = Optimization.OptimizationState(; iter = i + (epoch - 1) * length(data),
opt_state = Optimization.OptimizationState(;
iter = i + (epoch - 1) * length(data),
u = θ,
objective = first(x),
grad = gₜ,
Expand Down

0 comments on commit 29309e8

Please sign in to comment.