iterating through a DataLoader is still broken #843
Comments
Similarly on this line
this should probably be iter = i + (epoch - 1) * length(data) instead
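The suggested counter can be checked in isolation. The following is a minimal sketch in plain Julia, not the package's code: `data` and `epochs` are hypothetical stand-ins for the DataLoader and the number of passes over it. With this formula the counter keeps increasing across epochs instead of restarting at 1.

```julia
# Hypothetical stand-ins: `data` plays the role of the DataLoader,
# `epochs` is the number of passes over it.
data = collect(1:10)
epochs = 3

iters = Int[]
for epoch in 1:epochs
    for (i, d) in enumerate(data)
        # Global counter: position within this epoch plus all completed epochs.
        iter = i + (epoch - 1) * length(data)
        push!(iters, iter)
    end
end

# The counter covers 1, 2, ..., epochs * length(data) with no gaps or repeats.
@assert iters == collect(1:(epochs * length(data)))
```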
Just to continue this thread, MWE
function callback(state, l)
    @show state.iter
    if state.iter % 10 == 2
        println("stopping training!")
        return true
    else
        return false
    end
end
function lossf(θ, data)
    return sum(θ.^2)
end
dataloader = DataLoader(collect(1:10), batchsize = 1)
opt_func = OptimizationFunction(
    lossf,
    Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, [2.0], 1)
solve(opt_prob, Optimisers.Adam(0.001); callback, epochs = 5)
Which returns
@Vaibhavdixit02 you got this one?
From your output there, it seems to be doing the right thing? Each element of dataloader shows up thrice in the printout, so that matches the expectation, I think, but maybe I am misunderstanding something.
Yeah we should switch to a
No, the dataloader has 10 elements, but only the first 3 show up. You would expect it to iterate 3 times through all 10 elements, but instead it iterates 3 times through the first 3 elements.
Oh, I missed that, sorry. That makes sense, it's a bug. I am on it!
Everything here has been fixed, the iteration counts and nested breaking too.
Describe the bug 🐞
Iterating over a DataLoader is still broken on 4.0.3. If epochs is less than the number of elements in the DataLoader, only the first epochs elements get evaluated.
Expected behavior
Surely this should iterate over the entire dataset epochs times.
Minimal Reproducible Example 👇
Error & Stacktrace ⚠️
No stacktrace, but julia outputs this:
Environment (please complete the following information):
Latest versions of Optimization and OptimizationOptimisers
Additional context
The problem is in this line:
Optimization.jl/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Line 115 in f6dd24d
Probably the if condition should be something like epoch == maxiters && i == length(data). It would also be good to add a test.
Related to #835 and #842
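As a rough illustration of the proposed condition and of the kind of test asked for, here is a sketch using only Julia's `Test` standard library. `visit_all` is a hypothetical stand-in for the nested epoch/data loop, not the actual code in OptimizationOptimisers.jl. It assumes `maxiters` counts epochs (as in the condition above) and that a callback returning `true` should stop both loops.

```julia
using Test

# Hypothetical stand-in for the training loop; not the package's actual code.
function visit_all(data, maxiters; callback = (iter, d) -> false)
    visited = eltype(data)[]
    finished = false
    for epoch in 1:maxiters
        for (i, d) in enumerate(data)
            push!(visited, d)
            iter = i + (epoch - 1) * length(data)
            # Returning here leaves both loops at once when the callback halts.
            if callback(iter, d)
                return visited, finished
            end
            # Proposed condition: the final step is the last element of the last epoch.
            if epoch == maxiters && i == length(data)
                finished = true
            end
        end
    end
    return visited, finished
end

@testset "every element is visited in every epoch" begin
    data = collect(1:10)
    visited, finished = visit_all(data, 3)
    @test length(visited) == 3 * length(data)
    @test visited == repeat(data, 3)
    @test finished

    # A callback that halts at global iteration 12 stops in the second epoch.
    stopped, done = visit_all(data, 3; callback = (iter, d) -> iter == 12)
    @test length(stopped) == 12
    @test !done
end
```

A test against the real solver would need to record which batches the loss function receives. The sketch only pins down the expected count, `epochs * length(data)`.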