Implement luz_callback_validation_check #56
base: main
Conversation
```r
if (is.null(ctx$valid_data)) return()
if (self$batches <= 0) return()

ctx$model$eval()
```
We might want to extract this out into a function:
Lines 298 to 300 in 6e0bb77

```r
ctx$model$eval()
ctx$training <- FALSE
ctx$loss <- list()
```
And reuse it here, so we make sure the same state is always set?
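A hedged sketch of what that extracted helper could look like. The function name and its placement are assumptions, not part of the current luz code:

```r
# Hypothetical helper; the name `ctx_set_valid_state` is an assumption.
# It centralizes the state changes that must happen before validation,
# so the main loop and the check callback can't drift apart.
ctx_set_valid_state <- function(ctx) {
  ctx$model$eval()        # put the model in evaluation mode
  ctx$training <- FALSE   # flag consulted by callbacks and metrics
  ctx$loss <- list()      # reset the per-batch loss accumulator
}
```

Both the main validation loop and this callback could then call the one helper.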
```r
input <- list(batch[[1]])
target <- batch[[2]]
pred <- do.call(ctx$model, input)
self$loss <- ctx$model$loss(pred, target)
```
I think in general we would want to do the full validation step, because the errors could be in any of the callbacks, etc., but we would need to take care of the side effects that this might cause. We would need to call `valid_one_step()` and then make sure we can reset the state afterwards. I'm not sure yet what the best way to do that would be.
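One hedged way to contain those side effects would be to snapshot the mutable `ctx` fields before calling `valid_one_step()` and restore them afterwards. Which fields actually need saving is an open question; the pair of fields below is only an illustration:

```r
# Sketch only: the set of ctx fields that valid_one_step() mutates
# would need to be verified against luz; `loss` and `training` are
# assumptions used for illustration.
snapshot_ctx <- function(ctx) {
  list(loss = ctx$loss, training = ctx$training)
}

restore_ctx <- function(ctx, saved) {
  ctx$loss <- saved$loss
  ctx$training <- saved$training
}
```

The check could then wrap `valid_one_step()` between `snapshot_ctx()` and `restore_ctx()`, possibly via `on.exit()` so the state is restored even on error.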
I was thinking about this problem. It feels to me that the safest way would be, in `on_fit_begin()`, to call `fit` again and add a callback that breaks the training loop after `batches` steps, for both training and validation. This way, no side effects would interfere with the actual training loop, but we would still run the full loop, which would detect the other possible bugs.
I think this is possible if the first thing we do in the `ctx` object is to save a list of all the arguments that were passed to `fit`, before we do any kind of manipulation (like we do for callbacks).
To avoid infinite recursion, we could check `ctx$callbacks` to see whether the loop-breaking callback is already present.
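A hedged sketch of what that loop-breaking callback could look like. `luz_callback()` and the `ctx$iter` field exist in luz, but the actual mechanism for interrupting the epoch loop is an open design question; the `rlang::interrupt()` call below is a placeholder assumption, not the real luz break mechanism:

```r
library(luz)

# Sketch of a callback that stops after a fixed number of batches.
# How to cleanly break the loop (e.g. signalling a condition the way
# early stopping does internally) still needs to be decided; the
# interrupt() calls below are placeholders.
luz_callback_break_loop <- luz_callback(
  "break_loop_callback",
  initialize = function(batches = 2) {
    self$batches <- batches
  },
  on_train_batch_end = function() {
    if (ctx$iter >= self$batches)
      rlang::interrupt()  # placeholder for the real break mechanism
  },
  on_valid_batch_end = function() {
    if (ctx$iter >= self$batches)
      rlang::interrupt()  # placeholder for the real break mechanism
  }
)
```

The recursion guard would then check `ctx$callbacks` for an object inheriting from `"break_loop_callback"` before re-entering `fit`.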
Yeah, I've kind of gone in circles here. We want to call the validation callbacks so it's a complete check of the validation loop, but I was worried about any changes in state this might cause. I did consider using `valid_one_batch()`, but at the time decided against it for the reasons above.
A somewhat related question: when `ctx$call_callbacks("on_..._...")` is called and there are multiple callbacks with methods available for that breakpoint, in what order are they called? Default callbacks first, user-supplied callbacks second?
Yes, they are called in that order: default callbacks, then user callbacks.
I think that if we call `fit` again, there would be no interference; the only difference is that it would also test the training loop. But we could also skip that anyway...
I did actually think about calling `fit()` again inside `on_fit_begin()`, but decided against it. You're right, though: it would be a good way to check both the training and validation loops before committing to a full fit.
Thinking again, there could still be some side effects, e.g. the callbacks passed by the user could have side effects outside of the R session (maybe writing to a file or something like that). So maybe we want to call `fit` again with only the default callbacks plus the one that breaks the training loop.
This is not completely ideal, because there would still be callbacks that could fail in the "real" pass. But it sounds like enough, I guess.
I agree with only calling the default callbacks. My original reason for avoiding the callbacks was loggers and other things written to disk, but if we only run the default callbacks we avoid this issue. The function docs can just point out that user-supplied callbacks aren't validated.
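A hedged sketch of what the check's `on_fit_begin()` could then do under this plan. The `ctx$arguments` field (the saved `fit` arguments discussed above) and `luz_callback_break_loop` are assumptions, not existing luz API:

```r
# Sketch: re-run fit() with only the default callbacks plus a
# hypothetical loop-breaking callback. `ctx$arguments` assumes the
# fit() arguments were saved on ctx before any manipulation, as
# proposed above; none of these names are confirmed luz API.
on_fit_begin <- function() {
  # Guard against infinite recursion: skip if the loop-breaking
  # callback is already among ctx$callbacks.
  already <- any(vapply(ctx$callbacks, inherits,
                        logical(1), "break_loop_callback"))
  if (already) return()

  args <- ctx$arguments
  # Drop user callbacks; keep only the batch-limiting one. luz would
  # add its default callbacks again on the inner fit().
  args$callbacks <- list(luz_callback_break_loop(self$batches))
  do.call(fit, c(list(ctx$model), args))
}
```

This keeps user callbacks (and their disk or network side effects) out of the dry run while still exercising the full training and validation loops.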
Related to #5 (comment)
Hi @dfalbel,
This is a first attempt at implementing the validation check callback. It may still need some work so I am submitting this as a draft PR for now. By design, this check only runs a few batches and computes the loss. It does not strictly follow the standard validation loop because it does not call the validation-related callbacks.
We may also want to compute the validation metrics in this check. I will await your thoughts before more changes are made.