
Callbacks wishlist #5

Open · 4 of 5 tasks
dfalbel opened this issue May 11, 2021 · 7 comments

Comments


dfalbel commented May 11, 2021

  • Model checkpoint callback
  • Early stopping
  • Learning rate schedulers
  • ndjson/csv logger
  • Gradient clipping
mattwarkentin (Contributor) commented:

luz_callback_csv_logger() saves a space-delimited file, not a comma-separated one. I think sep = "," needs to be added here:

luz/R/callbacks.R, lines 564 to 570 in eada7e8:

utils::write.table(
  metrics,
  file = self$path,
  append = self$append,
  col.names = !self$append,
  row.names = FALSE
)
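
For reference, a minimal sketch of the suggested fix: the same call with the sep argument added (write.table() otherwise defaults to a single space as the separator).

utils::write.table(
  metrics,
  file = self$path,
  append = self$append,
  col.names = !self$append,
  row.names = FALSE,
  sep = ","  # write a comma-separated file, as the callback name promises
)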

mattwarkentin (Contributor) commented:

Hi @dfalbel, I have a couple of callback ideas. Is this issue a good place to share and brainstorm?


dfalbel commented Jul 21, 2021

Sure, @mattwarkentin!

mattwarkentin (Contributor) commented:

I'm a big fan of pytorch-lightning, and there are a few of their "trainer flags" (i.e. arguments) that I think could be implemented as luz callbacks and would make great additions.

  • luz_callback_validation_check(batches = 2)

    • Based on https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#num-sanity-val-steps.
    • This callback is super useful for diagnosing issues with the validation loop without having to wait for the first training epoch to complete. The training loop usually takes substantially longer than the validation loop, so you don't hit a validation-loop error until you've already wasted a lot of time. It's just a simple "sanity check" that everything works before committing to a long run.
    • I think the basic implementation would be for the callback to run on_fit_begin, apply the model to a small number of validation batches (2 is the default in pytorch-lightning), and, assuming no issues, move on to the normal fitting loop. A rough sketch follows this list.
    • I actually think this is such a helpful diagnostic check that it may warrant being included in the default set of callbacks.
  • luz_callback_overfit_batches(size = 0.1)

    • Based on https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#overfit-batches
    • This is another sort of "sanity check": if you can't overfit your model on a small training set, there is probably a bigger issue with either the data or the model. You would train and validate the model on a small proportion of the data, or a small number of batches, to achieve aggressive overfitting. It's a good prototyping step to make sure things are working well.
    • This sort of aggressive overfitting is considered good practice when starting a new project (example):
      • The approach I like to take to finding a good model has two stages: first get a model large enough that it can overfit (i.e. focus on training loss) and then regularize it appropriately (give up some training loss to improve the validation loss). The reason I like these two stages is that if we are not able to reach a low error rate with any model at all that may again indicate some issues, bugs, or misconfiguration.

    • This would be implemented similarly to the above callback, running on_fit_begin. In pytorch-lightning, the equivalent size argument is either a value between 0 and 1 specifying the proportion of the training data to use, or an integer indicating how many training batches to use. The whole fitting process would then run on that size-limited slice of the training data for both training AND validation, with the training data loader automatically set to shuffle = FALSE (if it isn't already) so that literally the same small set of training data is used for overfitting.
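
For illustration, here is a rough sketch of what the validation-check callback could look like. This is a guess at an implementation, not luz's code: it assumes the callback context exposes ctx$model and ctx$valid_data, and that each batch is a list whose first element is the model input.

luz_callback_validation_check <- luz::luz_callback(
  name = "validation_check",
  initialize = function(batches = 2) {
    self$batches <- batches
  },
  on_fit_begin = function() {
    # nothing to check if no validation data was supplied
    if (is.null(ctx$valid_data)) return(invisible(NULL))
    it <- torch::dataloader_make_iter(ctx$valid_data)
    for (i in seq_len(self$batches)) {
      batch <- torch::dataloader_next(it)
      if (is.null(batch)) break
      # forward pass only: surfaces shape/dtype/device errors before
      # the first (possibly long) training epoch runs
      torch::with_no_grad({
        ctx$model(batch[[1]])  # assumes batch[[1]] is the input
      })
    }
  }
)

Whether a forward pass alone is enough, or whether the loss and metrics should also be computed on those batches, is an open design question.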


dfalbel commented Jul 29, 2021

Thanks very much @mattwarkentin. I really like both callback ideas!

  • The validation check is very nice for UX and I agree with adding it to the list of default callbacks.

mattwarkentin (Contributor) commented:

Great to hear! I have a working version of luz_callback_validation_check() already, but it needs a little more work.

luz_callback_overfit_batches() has been a bit tougher since it requires modifying the dataloader. Hence my questions here: mlverse/torch#621


dfalbel commented Jul 29, 2021

Hmm, yeah, that one is tricky. FWIW it seems that lightning replaces the sampler for that. See https://github.com/PyTorchLightning/pytorch-lightning/blob/a64cc373946a755ce6c3aef57c1be607dfe29a0c/pytorch_lightning/trainer/data_loading.py#L169-L228

The way they do it is similar to mlverse/torch#621 (comment), i.e. create a new data loader with all the arguments equal to the first one's except the ones you want to change.
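
To illustrate that approach, here is a sketch (handling only the proportion case for brevity). The field access dl$dataset / dl$batch_size is an assumption about the dataloader object's internals rather than a documented API, and the helper name is made up.

library(torch)

# Rebuild a dataloader from an existing one, keeping only a small slice of
# its dataset and forcing shuffle = FALSE so the same batches repeat.
overfit_dataloader <- function(dl, size = 0.1) {
  ds <- dl$dataset                        # assumed field on the dataloader
  n_keep <- max(1, floor(length(ds) * size))
  small_ds <- dataset_subset(ds, indices = seq_len(n_keep))
  dataloader(
    small_ds,
    batch_size = dl$batch_size,           # assumed field on the dataloader
    shuffle = FALSE                       # same small set of batches every epoch
  )
}

In practice one would also want to carry over num_workers, collate_fn, and the other constructor arguments, which is essentially what lightning's re-instantiation trick does.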
