Generic weight averaging callback that supports EMA [wip] #20545
base: master
Conversation
Hey @senarvi, this looks great! I see you already added support for saving and resuming, which is great. There are many scenarios there (save every n steps, time-based, every epoch, etc.), so let's make sure we cover them all (for inspiration, we added quite a few tests in #20379).
No, I think it's better to have one callback with configurable averaging flags; it's more Lightning-esque.
I think this is ok, but I have a doubt about forcing this behavior. What do you think about it? I don't necessarily want to make the implementation more complex, so this is just for discussion.
It would be nice to make it configurable, and users will probably want to get to some minimum and then start averaging. The criteria to do so may be very bespoke, so maybe allowing the user to implement a custom hook to decide whether to start averaging, or whether to average at a given step, would be super handy. Otherwise I'm expecting users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging, which is totally fine but requires you to stop and resume. Regarding removing the StochasticWeightAveraging callback, I don't necessarily see that happening. We have a pretty strong commitment to backward compatibility at this point, so keeping it in with a notice to just use this one will not hurt.
That's a good point. I don't know what would be a good solution.
That's an interesting idea. We could have the user pass a function. It seems that AveragedModel will copy the current model parameters when called the first time, and update the average on subsequent calls. This means that the first average is computed on the second update. I checked how StochasticWeightAveraging does this and I think it doesn't work correctly: it only ever updates the average model parameters in on_train_epoch_start(), so the average is not updated after the last epoch. Just shows why I'd like to keep the logic as simple as possible.
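For reference, a minimal sketch of the AveragedModel behaviour described above (first call copies the parameters, later calls update the running average). The decay value and the model are just placeholders for illustration, not part of this PR:

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 2)

# EMA with decay 0.999; avg_fn receives (averaged_param, current_param, num_averaged).
ema_model = AveragedModel(
    model,
    avg_fn=lambda avg, cur, num: 0.999 * avg + (1 - 0.999) * cur,
    use_buffers=True,  # also average buffers such as batch norm statistics
)

ema_model.update_parameters(model)  # first call: copies the current parameters
# ... optimizer steps that modify `model` would go here ...
ema_model.update_parameters(model)  # subsequent calls: update the running average
```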
Hi, I have a couple questions.
During training (stage=fit), the actual LightningModule is what we update with the optimizer (I call it the current model), and an AveragedModel is maintained in the background (I call it the average model). I assume that validation is only run during training. Before and after validation we swap the current model and the average model, so the average model is what gets validated. When saving a checkpoint, we save the average model parameters in the state_dict, so if you later load the checkpoint without the WeightAveraging callback and run a test or export to ONNX, you will be using the average parameters. When training ends, we copy the average model parameters to the current model, so if you run a test or export to ONNX after training, you will be using the average parameters. That's the idea at least. I'm not confident that I have thought of every possible corner case. It would be great if you could test that it works in your case.
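A rough sketch of the hook structure described above. The class name and method bodies are hypothetical, not the actual PR code; the real implementation also has to handle checkpoint saving/loading, devices, and sanity-check validation:

```python
from lightning.pytorch.callbacks import Callback
from torch.optim.swa_utils import AveragedModel


class WeightAveragingSketch(Callback):
    """Illustrative only: maintains an average model and swaps it in for validation."""

    def on_fit_start(self, trainer, pl_module):
        self._average_model = AveragedModel(pl_module, use_buffers=True)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Update the running average after every optimizer step.
        self._average_model.update_parameters(pl_module)

    def on_validation_epoch_start(self, trainer, pl_module):
        # Swap in the averaged weights so that the average model gets validated.
        self._swap(pl_module)

    def on_validation_epoch_end(self, trainer, pl_module):
        # Swap back so that training continues from the current model.
        self._swap(pl_module)

    def on_train_end(self, trainer, pl_module):
        # Copy the averaged weights into the LightningModule for testing/export.
        pl_module.load_state_dict(self._average_model.module.state_dict())

    def _swap(self, pl_module):
        current = {k: v.detach().clone() for k, v in pl_module.state_dict().items()}
        pl_module.load_state_dict(self._average_model.module.state_dict())
        self._average_model.module.load_state_dict(current)
```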
@senarvi Ah! Thanks for the clarification, I should've checked the code more carefully. I tried your branch out on a model with quantization-aware training enabled, with ONNX export at the end, and everything is working beautifully! I hope this gets merged quickly.
Force-pushed the branch from efc77dc to 0010492.
The user can now provide either an update_on_step or an update_on_epoch function that defines when the average model is updated. For example: update_on_step = lambda x: x > 100, or update_on_epoch = lambda x: x in (3, 5, 7). I tested EMA in an actual learning task and it gave an improvement, so I'm starting to be more confident that this works. I think the biggest question that is still left is whether it's a problem that we force use_buffers=True.
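Usage could then look roughly like the sketch below. The WeightAveraging name and the update_on_step / update_on_epoch arguments come from this discussion, but the import path and the exact constructor signature are assumptions, since the PR is still work in progress:

```python
from lightning.pytorch import Trainer

# Assumed import path; the callback is what this PR proposes to add.
from lightning.pytorch.callbacks import WeightAveraging

# Start EMA updates only after the first 100 optimizer steps.
callback = WeightAveraging(update_on_step=lambda step: step > 100)

# Alternatively, update the average only after selected epochs:
# callback = WeightAveraging(update_on_epoch=lambda epoch: epoch in (3, 5, 7))

trainer = Trainer(max_epochs=10, callbacks=[callback])
```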
@tchaton I think you contributed the StochasticWeightAveraging callback.
* A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs.
* The user can provide a callback that defines after which steps or epochs the average model is updated.
Force-pushed the branch from 5f34205 to c8d50bd.
A callback that updates an AveragedModel after every training step
What does this PR do?
This is similar to the existing StochasticWeightAveraging callback, but uses the AveragedModel class from PyTorch. Reduced code duplication means easier maintenance. Also, any averaging function can be used. Currently this callback does averaging on every step. We could make this callback support both SWA and EMA, or we could still have different callbacks ("StepwiseAveragingCallback" and "EpochwiseAveragingCallback"). The biggest questions:
* Constructs the AveragedModel with use_buffers=True, so that an extra step is not needed for updating the batch normalization statistics. StochasticWeightAveraging performs an extra step in the end; consequently the implementation is significantly more complex and it's difficult to make sure that it works in all cases. Should we add this as an option in this class too? (See the sketch after this list.)
* Updates the average model after every step. StochasticWeightAveraging updates the average model after every epoch, and I recall that the original paper updated it only at certain points (the learning rate minima). I guess it would be nice to be able to select whether the average model will be updated after every step, after every epoch, or after certain epochs. Then we would need only one callback and we could remove the StochasticWeightAveraging callback, but would it make this class too complex?

Fixes #10914
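For context on the first question, here is a minimal sketch contrasting the two approaches: averaging the buffers together with the parameters versus averaging only the parameters and recomputing the batch norm statistics with an extra pass. The model and the commented-out DataLoader are placeholders for illustration:

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.BatchNorm1d(8))

# Option used in this PR: average the buffers together with the parameters,
# so no extra pass over the data is needed at the end of training.
averaged = AveragedModel(model, use_buffers=True)

# The alternative: average only the parameters and recompute the batch norm
# statistics with one extra pass over the training data at the end.
averaged_no_buffers = AveragedModel(model, use_buffers=False)
# train_loader = ...  # hypothetical DataLoader over the training set
# update_bn(train_loader, averaged_no_buffers)
```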
Before submitting
PR review
This pull request is still a work in progress and is open for discussion.
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20545.org.readthedocs.build/en/20545/