separate out fixed parameter and dataloader cases explicitly for now
Vaibhavdixit02 committed Sep 11, 2024
1 parent d09cf00 commit 5cf459a
Showing 6 changed files with 34 additions and 14 deletions.
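
This commit makes the OptimizationOptimisers and Sophia solve loops branch explicitly on whether the problem's parameter object is an MLUtils.DataLoader (iterated one minibatch at a time per epoch) or an ordinary fixed parameter object. A minimal usage sketch of the two cases follows; the loss function, data, AD backend, and optimiser settings are illustrative assumptions, not part of the commit.

using Optimization, OptimizationOptimisers, MLUtils, Zygote

# Hypothetical per-batch loss: θ is the optimization variable, d is either
# the fixed parameter object or one (x, y) minibatch from a DataLoader.
loss(θ, d) = sum(abs2, θ[1] .* d[1] .+ θ[2] .- d[2])

x = rand(100)
y = 2 .* x .+ 1.0
optf = OptimizationFunction(loss, Optimization.AutoZygote())

# Fixed parameter case: `p` is used as-is on every iteration.
prob_fixed = OptimizationProblem(optf, zeros(2), (x, y))
sol_fixed = solve(prob_fixed, Optimisers.Adam(0.05), maxiters = 100)

# DataLoader case: each epoch iterates over the minibatches.
loader = MLUtils.DataLoader((x, y), batchsize = 10)
prob_batched = OptimizationProblem(optf, zeros(2), loader)
sol_batched = solve(prob_batched, Optimisers.Adam(0.05), maxiters = 100)
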
1 change: 1 addition & 0 deletions Project.toml
@@ -11,6 +11,7 @@ LBFGSB = "5be7bae1-8223-5378-bac3-9e7378a2f6e6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OptimizationBase = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
2 changes: 1 addition & 1 deletion lib/OptimizationOptimJL/test/runtests.jl
@@ -42,7 +42,7 @@ end
b = 0.5)); callback = CallbackTester(length(x0)))
@test 10 * sol.objective < l1

f = OptimizationFunction(rosenbrock, SecondOrder(AutoForwardDiff(), AutoZygote()))
f = OptimizationFunction(rosenbrock, AutoReverseDiff())

Random.seed!(1234)
prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
1 change: 1 addition & 0 deletions lib/OptimizationOptimisers/Project.toml
@@ -4,6 +4,7 @@ authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.2.1"

[deps]
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
19 changes: 13 additions & 6 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
@@ -2,7 +2,7 @@ module OptimizationOptimisers

using Reexport, Printf, ProgressLogging
@reexport using Optimisers, Optimization
using Optimization.SciMLBase
using Optimization.SciMLBase, MLUtils

SciMLBase.supports_opt_cache_interface(opt::AbstractRule) = true
SciMLBase.requiresgradient(opt::AbstractRule) = true
@@ -57,10 +57,12 @@ function SciMLBase.__solve(cache::OptimizationCache{
throw(ArgumentError("The number of epochs must be specified as the epochs or maxiters kwarg."))
end

if cache.p == SciMLBase.NullParameters()
data = OptimizationBase.DEFAULT_DATA
else
if cache.p isa MLUtils.DataLoader
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
end
opt = cache.opt
θ = copy(cache.u0)
@@ -77,11 +79,16 @@ function SciMLBase.__solve(cache::OptimizationCache{
Optimization.@withprogress cache.progress name="Training" begin
for _ in 1:maxiters
for (i, d) in enumerate(data)
if cache.f.fg !== nothing
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
else
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
else
cache.f.grad(G, θ)
x = cache.f(θ)
end
opt_state = Optimization.OptimizationState(iter = i,
u = θ,
23 changes: 17 additions & 6 deletions src/sophia.jl
@@ -64,10 +64,12 @@ function SciMLBase.__solve(cache::OptimizationCache{

maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)

if cache.p == SciMLBase.NullParameters()
data = OptimizationBase.DEFAULT_DATA
else
if cache.p isa MLUtils.DataLoader
data = cache.p
dataiterate = true
else
data = [cache.p]
dataiterate = false
end

f = cache.f
@@ -77,14 +79,23 @@ function SciMLBase.__solve(cache::OptimizationCache{
hₜ = zero(θ)
for _ in 1:maxiters
for (i, d) in enumerate(data)
f.grad(gₜ, θ, d)
x = cache.f(θ, d)
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
elseif dataiterate
cache.f.grad(G, θ, d)
x = cache.f(θ, d)
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
else
cache.f.grad(G, θ)
x = cache.f(θ)
end
opt_state = Optimization.OptimizationState(; iter = i,
u = θ,
objective = first(x),
grad = gₜ,
original = nothing)
cb_call = cache.callback(θ, x...)
cb_call = cache.callback(opt_state, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
elseif cb_call
2 changes: 1 addition & 1 deletion test/minibatch.jl
@@ -19,7 +19,7 @@ function dudt_(u, p, t)
ann(u, p, st)[1] .* u
end

function callback(state, l) #callback function to observe training
function callback(state, l, pred) #callback function to observe training
display(l)
return false
end
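
This test change tracks the updated solver callbacks, which now pass an Optimization.OptimizationState as the first argument followed by the values returned by the objective. A hedged sketch of how such a callback can use the state; the name and logging body are illustrative, not from the commit.

# `state` carries fields such as iter, u, objective, and grad; `l` is the loss
# and `pred` is the extra value returned by the objective function.
function training_callback(state, l, pred)
    println("iteration ", state.iter, ": loss = ", l)
    return false   # return true to halt the optimization early
end
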
