From 5cf459a0bf5c092b1839be6030f2544b0d6bb260 Mon Sep 17 00:00:00 2001
From: Vaibhav Dixit
Date: Wed, 11 Sep 2024 19:19:31 -0400
Subject: [PATCH] separate out fixed parameter and dataloader cases explicitly
 for now

---
 Project.toml                             |  1 +
 lib/OptimizationOptimJL/test/runtests.jl |  2 +-
 lib/OptimizationOptimisers/Project.toml  |  1 +
 .../src/OptimizationOptimisers.jl        | 19 ++++++++++-----
 src/sophia.jl                            | 23 +++++++++++++-----
 test/minibatch.jl                        |  2 +-
 6 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/Project.toml b/Project.toml
index d204b8fe4..26d959d80 100644
--- a/Project.toml
+++ b/Project.toml
@@ -11,6 +11,7 @@ LBFGSB = "5be7bae1-8223-5378-bac3-9e7378a2f6e6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
diff --git a/lib/OptimizationOptimJL/test/runtests.jl b/lib/OptimizationOptimJL/test/runtests.jl
index f43bfca1a..545d96f71 100644
--- a/lib/OptimizationOptimJL/test/runtests.jl
+++ b/lib/OptimizationOptimJL/test/runtests.jl
@@ -42,7 +42,7 @@ end
             b = 0.5));
         callback = CallbackTester(length(x0)))
     @test 10 * sol.objective < l1
-    f = OptimizationFunction(rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()))
+    f = OptimizationFunction(rosenbrock, AutoReverseDiff())
     Random.seed!(1234)
     prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
diff --git a/lib/OptimizationOptimisers/Project.toml b/lib/OptimizationOptimisers/Project.toml
index a0468b426..bdae71df9 100644
--- a/lib/OptimizationOptimisers/Project.toml
+++ b/lib/OptimizationOptimisers/Project.toml
@@ -4,6 +4,7 @@ authors = ["Vaibhav Dixit and contributors"]
 version = "0.2.1"
 
 [deps]
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
index 001a2dac6..daa7399d3 100644
--- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
+++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
@@ -2,7 +2,7 @@ module OptimizationOptimisers
 
 using Reexport, Printf, ProgressLogging
 @reexport using Optimisers, Optimization
-using Optimization.SciMLBase
+using Optimization.SciMLBase, MLUtils
 
 SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
 SciMLBase.requiresgradient(opt::AbstractRule) = true
@@ -57,10 +57,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
         throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
     end
 
-    if cache.p == SciMLBase.NullParameters()
-        data = OptimizationBase.DEFAULT_DATA
-    else
+    if cache.p isa MLUtils.DataLoader
         data = cache.p
+        dataiterate = true
+    else
+        data = [cache.p]
+        dataiterate = false
     end
     opt = cache.opt
     θ = copy(cache.u0)
@@ -77,11 +79,16 @@
     Optimization.@withprogress cache.progress name="Training" begin
         for _ in 1:maxiters
             for (i, d) in enumerate(data)
-                if cache.f.fg !== nothing
+                if cache.f.fg !== nothing && dataiterate
                     x = cache.f.fg(G, θ, d)
-                else
+                elseif dataiterate
                     cache.f.grad(G, θ, d)
                     x = cache.f(θ, d)
+                elseif cache.f.fg !== nothing
+                    x = cache.f.fg(G, θ)
+                else
+                    cache.f.grad(G, θ)
+                    x = cache.f(θ)
                 end
                 opt_state = Optimization.OptimizationState(iter = i,
                     u = θ,
diff --git a/src/sophia.jl b/src/sophia.jl
index cd17e0f69..2bf602ce8 100644
--- a/src/sophia.jl
+++ b/src/sophia.jl
@@ -64,10 +64,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
 
     maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
 
-    if cache.p == SciMLBase.NullParameters()
-        data = OptimizationBase.DEFAULT_DATA
-    else
+    if cache.p isa MLUtils.DataLoader
         data = cache.p
+        dataiterate = true
+    else
+        data = [cache.p]
+        dataiterate = false
     end
 
     f = cache.f
@@ -77,14 +79,23 @@ function SciMLBase.__solve(cache::OptimizationCache{
     hₜ = zero(θ)
     for _ in 1:maxiters
         for (i, d) in enumerate(data)
-            f.grad(gₜ, θ, d)
-            x = cache.f(θ, d)
+            if cache.f.fg !== nothing && dataiterate
+                x = cache.f.fg(gₜ, θ, d)
+            elseif dataiterate
+                cache.f.grad(gₜ, θ, d)
+                x = cache.f(θ, d)
+            elseif cache.f.fg !== nothing
+                x = cache.f.fg(gₜ, θ)
+            else
+                cache.f.grad(gₜ, θ)
+                x = cache.f(θ)
+            end
             opt_state = Optimization.OptimizationState(; iter = i,
                 u = θ,
                 objective = first(x),
                 grad = gₜ,
                 original = nothing)
-            cb_call = cache.callback(θ, x...)
+            cb_call = cache.callback(opt_state, x...)
             if !(cb_call isa Bool)
                 error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
             elseif cb_call
diff --git a/test/minibatch.jl b/test/minibatch.jl
index f818f4ee1..aea533a95 100644
--- a/test/minibatch.jl
+++ b/test/minibatch.jl
@@ -19,7 +19,7 @@ function dudt_(u, p, t)
     ann(u, p, st)[1] .* u
 end
 
-function callback(state, l) #callback function to observe training
+function callback(state, l, pred) #callback function to observe training
     display(l)
     return false
 end
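A quick illustration of the two code paths this patch separates: minibatching is
now opted into by passing an MLUtils.DataLoader as the problem's parameter
object (the dataiterate = true branch), while any other `p` is wrapped in a
one-element collection and treated as a fixed parameter. The sketch below is
illustrative only, not part of the patch: it assumes Zygote as the gradient
backend, and the toy data, hyperparameters, and names (`loss`, `x`, `y`, `cb`)
are made up for the example.

    using Optimization, OptimizationOptimisers, MLUtils, Zygote

    x = rand(2, 128)                      # toy features
    y = 2 .* x .+ 0.1 .* randn(2, 128)    # toy targets

    # Loss over one minibatch d = (xb, yb) yielded by the DataLoader
    function loss(θ, d)
        xb, yb = d
        return sum(abs2, θ[1] .* xb .+ θ[2] .- yb) / size(xb, 2)
    end

    optf = OptimizationFunction(loss, Optimization.AutoZygote())
    # A DataLoader as `p` takes the dataiterate = true branch above; passing
    # the tuple (x, y) directly would take the fixed-parameter branch instead.
    prob = OptimizationProblem(optf, zeros(2), DataLoader((x, y), batchsize = 16))

    # Callbacks now receive an OptimizationState first (see the sophia.jl and
    # test/minibatch.jl hunks), followed by the splatted return values of the
    # loss; a loss returning (l, pred) would call cb as cb(state, l, pred).
    cb = (state, l) -> (println(l); false)
    sol = solve(prob, Optimisers.Adam(0.01); callback = cb, epochs = 10)

Here `epochs` is the keyword named in the ArgumentError added above; `maxiters`
works as well.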