iterating through a DataLoader is still broken #843

Closed
tiemvanderdeure opened this issue Oct 17, 2024 · 7 comments
Labels
bug Something isn't working

Comments

@tiemvanderdeure

Describe the bug 🐞
Iterating over a DataLoader is still broken on 4.0.3. If epochs is less than the number of elements in DataLoader, only the first epochs elements get evaluated.

Expected behavior
Surely this should iterate over the entire dataset epochs times

Minimal Reproducible Example 👇

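# Packages this example appears to rely on (DataLoader comes from MLUtils)
using Optimization, OptimizationOptimisers, MLUtils, Zygote

function lossf(θ, data)
    @show data
    return sum(θ.^2)
end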
dataloader = DataLoader(collect(1:10), batchsize = 1)
opt_func = OptimizationFunction(
    lossf,
    Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, [2.0], dataloader)
res_adam = solve(
    opt_prob, Optimisers.Adam(0.001); epochs = 3)

Error & Stacktrace ⚠️
No stacktrace, but Julia prints this:

data = [1]
data = [2]
data = [3]
data = [3]
data = [1]
data = [2]
data = [3]
data = [3]
data = [1]
data = [2]
data = [3]
data = [3]
retcode: Default
u: 1-element Vector{Float64}:
 1.9940003690339771

Environment (please complete the following information):
Latest versions of Optimization and OptimizationOptimisers

Additional context
The problem is in this line:

if i == maxiters #Last iter, revert to best.

Probably the if condition should be something like epoch == maxiters && i == length(data). It would also be good to add a test.
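
To make the suspected cause concrete, here is a minimal sketch in plain Julia that does not use Optimization.jl at all. It is a simplified stand-in for the solver loop, not the library's actual source. The function name visited_batches and the fixed keyword are hypothetical, and it assumes that the "last iteration" check compares the per-epoch index i against a maxiters equal to the number of epochs.

```julia
# Simplified model of the nested epoch/batch loop (not the library's code).
# `fixed = false` mimics the suspected bug, `fixed = true` the proposed condition.
function visited_batches(data, epochs; fixed::Bool)
    maxiters = epochs                 # assumption: the check compares against the epoch count
    seen = Int[]
    for epoch in 1:epochs
        for (i, d) in enumerate(data)
            push!(seen, d)            # record which batch was evaluated
            last_iter = fixed ? (epoch == epochs && i == length(data)) : (i == maxiters)
            last_iter && break        # the "last iteration" branch leaves the inner loop
        end
    end
    return seen
end

buggy_seen = visited_batches(1:10, 3; fixed = false)
fixed_seen = visited_batches(1:10, 3; fixed = true)

@show buggy_seen          # only the first 3 batches, repeated once per epoch
@show length(fixed_seen)  # every batch, once per epoch
```

Under these assumptions the buggy variant only ever touches the first 3 batches, which matches the pattern in the output above apart from the extra final evaluation.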

Related to #835 and #842

@tiemvanderdeure added the "bug" label Oct 17, 2024
@tiemvanderdeure changed the title from "epochs is still broken" to "iterating through a DataLoader is still broken" Oct 17, 2024
@tiemvanderdeure
Author

Similarly on this line

opt_state = Optimization.OptimizationState(iter = i,

this should probably be iter = i + (epoch - 1) * length(data) instead
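
As a quick sanity check of that formula, the following sketch (plain Julia, with a hypothetical helper name global_iter) confirms that it yields a counter that runs continuously across epochs instead of restarting at 1 in each epoch.

```julia
# Hypothetical helper: global iteration index from the per-epoch index `i`,
# the 1-based `epoch`, and the number of batches per epoch `n`.
global_iter(i, epoch, n) = i + (epoch - 1) * n

n = 10        # batches per epoch, as in the DataLoader above
epochs = 3

# Outer loop over epochs, inner loop over batches, in the same order as the solver.
iters = [global_iter(i, epoch, n) for epoch in 1:epochs for i in 1:n]

@show first(iters), last(iters)   # runs from 1 up to n * epochs
```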

@tiemvanderdeure
Author

tiemvanderdeure commented Oct 18, 2024

Just to continue this thread, the callback also doesn't work as expected. If the callback returns true, this only breaks the inner loop, meaning iteration through the data is reset, but the solver continues with the next epoch regardless. If the data is something other than a DataLoader, returning true from the callback doesn't do anything at all.

MWE

function callback(state, l)
    @show state.iter
    if state.iter % 10 == 2
        println("stopping training!")
        return true
    else
        return false
    end
end
function lossf(θ, data)
    return sum(θ.^2)
end
dataloader = DataLoader(collect(1:10), batchsize = 1)
opt_func = OptimizationFunction(
    lossf,
    Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, [2.0], 1)
solve(opt_prob, Optimisers.Adam(0.001); callback, epochs = 5)

Which returns

state.iter = 1
state.iter = 2
stopping training!
state.iter = 3
state.iter = 4
state.iter = 5
state.iter = 5
retcode: Default
u: 1-element Vector{Float64}:
 1.9970000478972731

@ChrisRackauckas
Member

@Vaibhavdixit02 you got this one?

@Vaibhavdixit02
Member

> If epochs is less than the number of elements in DataLoader, only the first epochs elements get evaluated.
>
> Expected behavior
> Surely this should iterate over the entire dataset epochs times

From your output there, it seems to be doing the right thing? Each element of dataloader shows up thrice in the print out so that matches the expectation I think, but maybe I am misunderstanding something.

> Just to continue this thread, the callback also doesn't work as expected. If the callback returns true, this only breaks the inner loop, meaning iteration through the data is reset, but the solver continues with the next epoch regardless. If the data is something other than a DataLoader, returning true from the callback doesn't do anything at all.

Yeah we should switch to a @goto there 👍
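
For readers unfamiliar with the pattern, here is a small sketch of why a plain break is not enough in a nested loop and how Julia's @goto / @label leaves both loops at once. It is plain Julia with hypothetical function names, and the stop_at argument merely stands in for the callback asking the solver to halt. It is not the change that was made in the package.

```julia
# `break` only leaves the inner loop, so later epochs still run.
function run_with_break(data, epochs, stop_at)
    iter = 0
    for epoch in 1:epochs
        for d in data
            iter += 1
            iter == stop_at && break   # stands in for the callback requesting a halt
        end
    end
    return iter
end

# `@goto` jumps past both loops, so the halt request ends the whole run.
function run_with_goto(data, epochs, stop_at)
    iter = 0
    for epoch in 1:epochs
        for d in data
            iter += 1
            if iter == stop_at
                @goto done
            end
        end
    end
    @label done
    return iter
end

@show run_with_break(1:10, 5, 12)   # keeps going through the remaining epochs
@show run_with_goto(1:10, 5, 12)    # stops at the requested iteration
```

An alternative with the same effect would be a boolean flag checked in the outer loop, or moving the loops into a function and using return.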

@tiemvanderdeure
Author

> Each element of dataloader shows up thrice in the print out so that matches the expectation I think, but maybe I am misunderstanding something.

No, the dataloader has 10 elements, but only the first 3 show up. You would expect it to iterate 3 times through all 10 elements, but instead it iterates 3 times through the first 3 elements.

@Vaibhavdixit02
Member

Oh, I missed that, sorry. That makes sense, it's a bug. I am on it!

@Vaibhavdixit02
Member

julia> function lossf(θ, data)
           @show data
           return sum(θ.^2)
       end
lossf (generic function with 1 method)

julia> dataloader = DataLoader(collect(1:10), batchsize = 1)
10-element DataLoader(::Vector{Int64})
  with first element:
  1-element Vector{Int64}

julia> opt_func = OptimizationFunction(
           lossf,
           Optimization.AutoZygote())
(::OptimizationFunction{true, AutoZygote, typeof(lossf), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}) (generic function with 1 method)

julia> opt_prob = OptimizationProblem(opt_func, [2.0], dataloader)
OptimizationProblem. In-place: true
u0: 1-element Vector{Float64}:
 2.0

julia> res_adam = solve(
           opt_prob, Optimisers.Adam(0.001); epochs = 3)
data = [1]
data = [2]
data = [3]
data = [4]
data = [5]
data = [6]
data = [7]
data = [8]
data = [9]
data = [10]
data = [1]
data = [2]
data = [3]
data = [4]
data = [5]
data = [6]
data = [7]
data = [8]
data = [9]
data = [10]
data = [1]
data = [2]
data = [3]
data = [4]
data = [5]
data = [6]
data = [7]
data = [8]
data = [9]
data = [10]
retcode: Default
u: 1-element Vector{Float64}:
 1.97003760066008

Everything here has been fixed, including the iteration counts and the nested breaking.
