From 47a2481c393fe1115f378b5401c486025459b087 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 27 Oct 2024 11:51:25 -0400 Subject: [PATCH 1/5] MOI vector lambda and iteration fixes in Optimisers --- lib/OptimizationMOI/src/nlp.jl | 2 +- .../src/OptimizationOptimisers.jl | 53 ++++++++++++------- lib/OptimizationOptimisers/test/runtests.jl | 7 +++ src/sophia.jl | 3 +- 4 files changed, 43 insertions(+), 22 deletions(-) diff --git a/lib/OptimizationMOI/src/nlp.jl b/lib/OptimizationMOI/src/nlp.jl index dbfb80089..b745c0aa0 100644 --- a/lib/OptimizationMOI/src/nlp.jl +++ b/lib/OptimizationMOI/src/nlp.jl @@ -375,7 +375,7 @@ function MOI.eval_hessian_lagrangian(evaluator::MOIOptimizationNLPEvaluator{T}, σ, μ) where {T} if evaluator.f.lag_h !== nothing - evaluator.f.lag_h(h, x, σ, μ) + evaluator.f.lag_h(h, x, σ, Vector(μ)) return end if evaluator.f.hess === nothing diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 99743d24d..f4db587b0 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -42,27 +42,27 @@ function SciMLBase.__solve(cache::OptimizationCache{ P, C } - maxiters = if cache.solver_args.epochs === nothing + if OptimizationBase.isa_dataiterator(cache.p) + data = cache.p + dataiterate = true + else + data = [cache.p] + dataiterate = false + end + + epochs = if cache.solver_args.epochs === nothing if cache.solver_args.maxiters === nothing - throw(ArgumentError("The number of epochs must be specified with either the epochs or maxiters kwarg.")) + throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) else - cache.solver_args.maxiters + cache.solver_args.maxiters / length(data) end else cache.solver_args.epochs end - maxiters = Optimization._check_and_convert_maxiters(maxiters) - if maxiters === nothing - throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg.")) - end - - if OptimizationBase.isa_dataiterator(cache.p) - data = cache.p - dataiterate = true - else - data = [cache.p] - dataiterate = false + epochs = Optimization._check_and_convert_maxiters(epochs) + if epochs === nothing + throw(ArgumentError("The number of iterations must be specified with either the epochs or maxiters kwarg. Where maxiters = epochs*length(data).")) end opt = cache.opt @@ -75,21 +75,35 @@ function SciMLBase.__solve(cache::OptimizationCache{ min_θ = cache.u0 state = Optimisers.setup(opt, θ) - + iterations = 0 + fevals = 0 + gevals = 0 t0 = time() Optimization.@withprogress cache.progress name="Training" begin - for epoch in 1:maxiters + for epoch in 1:epochs for (i, d) in enumerate(data) if cache.f.fg !== nothing && dataiterate x = cache.f.fg(G, θ, d) + iterations += 1 + fevals += 1 + gevals += 1 elseif dataiterate cache.f.grad(G, θ, d) x = cache.f(θ, d) + iterations += 1 + fevals += 2 + gevals += 1 elseif cache.f.fg !== nothing x = cache.f.fg(G, θ) + iterations += 1 + fevals += 1 + gevals += 1 else cache.f.grad(G, θ) x = cache.f(θ) + iterations += 1 + fevals += 2 + gevals += 1 end opt_state = Optimization.OptimizationState( iter = i + (epoch - 1) * length(data), @@ -112,7 +126,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ min_err = x min_θ = copy(θ) end - if i == maxiters #Last iter, revert to best. + if i == length(data) #Last iter, revert to best. opt = min_opt x = min_err θ = min_θ @@ -132,10 +146,9 @@ function SciMLBase.__solve(cache::OptimizationCache{ end t1 = time() - stats = Optimization.OptimizationStats(; iterations = maxiters, - time = t1 - t0, fevals = maxiters, gevals = maxiters) + stats = Optimization.OptimizationStats(; iterations, + time = t1 - t0, fevals, gevals) SciMLBase.build_solution(cache, cache.opt, θ, first(x)[1], stats = stats) - # here should be build_solution to create the output message end end diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 12b6f2754..4728cbf25 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -27,6 +27,9 @@ using Zygote sol = solve(prob, Optimisers.Adam(), maxiters = 1000) @test 10 * sol.objective < l1 + @test sol.stats.iterations == 1000 + @test sol.stats.fevals == 1000 + @test sol.stats.gevals == 1000 @testset "cache" begin objective(x, p) = (p[1] - x[1])^2 @@ -99,6 +102,10 @@ end res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000) @test res.objective < 1e-4 + @test res.stats.iterations == 10000*length(data) + @test res.stats.fevals == 10000*length(data) + @test res.stats.gevals == 10000*length(data) + using MLDataDevices data = CPUDevice()(data) diff --git a/src/sophia.jl b/src/sophia.jl index b63f0c099..9f4d973e9 100644 --- a/src/sophia.jl +++ b/src/sophia.jl @@ -88,7 +88,8 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache.f.grad(gₜ, θ) x = cache.f(θ) end - opt_state = Optimization.OptimizationState(; iter = i + (epoch - 1) * length(data), + opt_state = Optimization.OptimizationState(; + iter = i + (epoch - 1) * length(data), u = θ, objective = first(x), grad = gₜ, From 4131720dbf236091c3aec0bbf51f3175ea043fa6 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Sun, 27 Oct 2024 13:57:43 -0400 Subject: [PATCH 2/5] fix nested breaking and iteration counts --- .../src/OptimizationOptimisers.jl | 14 +++++++++----- lib/OptimizationOptimisers/test/runtests.jl | 13 ++++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index f4db587b0..da2299e01 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -79,8 +79,12 @@ function SciMLBase.__solve(cache::OptimizationCache{ fevals = 0 gevals = 0 t0 = time() + breakall = false Optimization.@withprogress cache.progress name="Training" begin for epoch in 1:epochs + if breakall + break + end for (i, d) in enumerate(data) if cache.f.fg !== nothing && dataiterate x = cache.f.fg(G, θ, d) @@ -111,10 +115,10 @@ function SciMLBase.__solve(cache::OptimizationCache{ objective = x[1], grad = G, original = state) - cb_call = cache.callback(opt_state, x...) - if !(cb_call isa Bool) + breakall = cache.callback(opt_state, x...) + if !(breakall isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.") - elseif cb_call + elseif breakall break end msg = @sprintf("loss: %.3g", first(x)[1]) @@ -126,7 +130,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ min_err = x min_θ = copy(θ) end - if i == length(data) #Last iter, revert to best. + if i == length(data)*epochs #Last iter, revert to best. opt = min_opt x = min_err θ = min_θ @@ -136,7 +140,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ objective = x[1], grad = G, original = state) - cache.callback(opt_state, x...) + breakall = cache.callback(opt_state, x...) break end end diff --git a/lib/OptimizationOptimisers/test/runtests.jl b/lib/OptimizationOptimisers/test/runtests.jl index 4728cbf25..953db8960 100644 --- a/lib/OptimizationOptimisers/test/runtests.jl +++ b/lib/OptimizationOptimisers/test/runtests.jl @@ -76,7 +76,7 @@ end using Optimization, OptimizationOptimisers, Lux, Zygote, MLUtils, Random, ComponentArrays - x = rand(10000) + x = rand(Float32, 10000) y = sin.(x) data = MLUtils.DataLoader((x, y), batchsize = 100) @@ -99,13 +99,16 @@ end optf = OptimizationFunction(loss, AutoZygote()) prob = OptimizationProblem(optf, ps_ca, data) - res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 10000) + res = Optimization.solve(prob, Optimisers.Adam(), epochs = 50) @test res.objective < 1e-4 - @test res.stats.iterations == 10000*length(data) - @test res.stats.fevals == 10000*length(data) - @test res.stats.gevals == 10000*length(data) + @test res.stats.iterations == 50*length(data) + @test res.stats.fevals == 50*length(data) + @test res.stats.gevals == 50*length(data) + + res = Optimization.solve(prob, Optimisers.Adam(), callback = callback, epochs = 100) + @test res.objective < 1e-4 using MLDataDevices data = CPUDevice()(data) From 2e63a8e6555df0257ccc49aa409bdc675c768ce1 Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Sun, 27 Oct 2024 14:59:20 -0400 Subject: [PATCH 3/5] Update Project.toml --- lib/OptimizationMOI/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OptimizationMOI/Project.toml b/lib/OptimizationMOI/Project.toml index 5df31b4a0..6069c73a1 100644 --- a/lib/OptimizationMOI/Project.toml +++ b/lib/OptimizationMOI/Project.toml @@ -1,7 +1,7 @@ name = "OptimizationMOI" uuid = "fd9f6733-72f4-499f-8506-86b2bdd0dea1" authors = ["Vaibhav Dixit and contributors"] -version = "0.5.0" +version = "0.5.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 1d0aa19ad8f223f71370d60ed8f14c35af78ed8a Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Sun, 27 Oct 2024 14:59:34 -0400 Subject: [PATCH 4/5] Update Project.toml --- lib/OptimizationOptimisers/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml index b0e763c2f..f2070226f 100644 --- a/lib/OptimizationOptimisers/Project.toml +++ b/lib/OptimizationOptimisers/Project.toml @@ -1,7 +1,7 @@ name = "OptimizationOptimisers" uuid = "42dfb2eb-d2b4-4451-abcd-913932933ac1" authors = ["Vaibhav Dixit and contributors"] -version = "0.3.3" +version = "0.3.4" [deps] Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" From 3d02908f2e66695f2b66790acb8730e14efe3922 Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Sun, 27 Oct 2024 14:59:47 -0400 Subject: [PATCH 5/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e484ab75c..7db077da5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Optimization" uuid = "7f7a1694-90dd-40f0-9382-eb1efda571ba" -version = "4.0.3" +version = "4.0.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"