float8 training with rowwise scaling #889
Comments
I was thinking about similar topics for int8 training too. Just curious, how do you plan to handle backward pass for
Well, at a high level, I want If I had to guess today what I expect to work well with FSDP, I'd say either tensorwise or blockwise scaling of the weight, so we can scale once and transpose without rescaling. In the future it would be great to have scaled float8 gemm support where one of the arguments is scaled tensorwise/blockwise (the weight), and the other rowwise.
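To make the "scale once and transpose without rescaling" point concrete, here is a minimal pure-Python sketch (not torchao code). It assumes the common convention `scale = FP8_MAX / amax`, with 448.0 as the largest finite value of float8 e4m3fn. The helper names are hypothetical. A tensorwise scale is unchanged by transposition, while the rowwise scales of the transposed weight are a different set of values, so they would have to be recomputed.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def tensorwise_scale(m):
    """One scale for the whole tensor (assumes a nonzero amax)."""
    amax = max(abs(x) for row in m for x in row)
    return FP8_E4M3_MAX / amax


def rowwise_scales(m):
    """One scale per row (assumes every row has a nonzero amax)."""
    return [FP8_E4M3_MAX / max(abs(x) for x in row) for row in m]


w = [[1.0, -2.0, 4.0],
     [8.0, 0.5, -16.0]]
wt = transpose(w)

# Tensorwise: the global amax does not depend on layout.
same_tensorwise = tensorwise_scale(w) == tensorwise_scale(wt)

# Rowwise: rows of w^T are columns of w, so the scales differ.
scales_w = rowwise_scales(w)
scales_wt = rowwise_scales(wt)
```

With rowwise scaling, the transposed weight used in the backward pass therefore needs its own amax reduction along the other axis, which is the extra cost the comment above refers to.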
@gau-nernst I will say that supporting row-wise scaling with FSDP2 pre-all-gather is painful, and I am not sure if we should ever do it (at least with FSDP2 -- maybe for some other FSDP implementations, it can work). I have some more thoughts, but for one, if the user enables activation checkpointing, then the backward all-gather should now all-gather both for
The recipe we've been working on does row-wise scaling of the weights post-all-gather (hence the comms happen in bf16), and tensor-wise scaling (+transposition) of the weights in the backward (leveraging the scaling factor of the forward to avoid recomputing a full amax from scratch). Ideally we would have done real column-wise scaling of weights in the backward but that was too hard.
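One way to read "leveraging the scaling factor of the forward" is sketched below. This is an assumption about the idea, not the actual recipe's code. If the forward rowwise scales are `FP8_MAX / amax_row`, then the global amax is the largest row amax, so the tensorwise scale for the backward is the smallest forward row scale. That takes a reduction over one value per row instead of over the full weight. All function names are hypothetical.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def rowwise_scales(w):
    """Forward pass: one scale per row (assumes nonzero row amax)."""
    return [FP8_E4M3_MAX / max(abs(x) for x in row) for row in w]


def tensorwise_scale_from_rowwise(row_scales):
    """Backward pass: derive the tensorwise scale from the forward scales.

    scale_i = FP8_MAX / amax_i, so the largest amax corresponds to the
    smallest scale, and min(row_scales) equals FP8_MAX / global_amax.
    """
    return min(row_scales)


def tensorwise_scale_from_scratch(w):
    """Reference: full amax reduction over every element."""
    return FP8_E4M3_MAX / max(abs(x) for row in w for x in row)


w = [[1.0, -2.0, 4.0],
     [8.0, 0.5, -16.0]]

fwd_scales = rowwise_scales(w)
bwd_scale = tensorwise_scale_from_rowwise(fwd_scales)
```

Because the tensorwise scale is layout independent, the same `bwd_scale` can be applied to the transposed weight without a second reduction.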
BTW, I plan to tackle some of these points in the next few days. I'll report here with where I get to.
@vkuzo if we enable training for
@jerryzh168 @vkuzo I'd appreciate clarity on this ASAP to avoid investing too much time in this if it's unnecessary, thanks!
~hours to days of work, IMO
~weeks of alignment and work IMO, if we include aligning everyone that this is in fact what we want (I personally am not convinced just yet, AffineQuantizedTensor seems a bit too complicated for what float8 needs right now), and getting float8 training to the same level of robustness there as it is in torchao.float8. Given the timeline estimates above, I would just add rowwise scaling to
work-in-progress is here: #940
This is a brain dump of what is missing from `torchao.float8` to support training with rowwise scaling, to help if someone wants to jump in to build this.

already done
- `torch._scaled_mm` supports rowwise scaling
- `max-autotune` mode (I haven't personally tested this yet)

needed
- `Float8Tensor` to work with rowwise scales. We had an unlanded PR on `float8_experimental` doing that here ([wip] add axiswise granularity to Float8Tensor pytorch-labs/float8_experimental#352), just never got the time to land it. You can reuse that PR or do something similar. Note that [Float8Quant] Add rowwise scaling option to float8 dyanmic quant #819 landed recently adding float8 rowwise scaling to inference, so being consistent with that where applicable would be nice.
- `Float8Linear` to be configurable with rowwise scales for each argument, and for the scaling to respect the config, validated by tests + benchmarks, would require changes to `torchao.float8.config.py` and `torchao.float8.float8_linear.py`.
- `torchao.float8` and inductor, if needed based on how well inductor generates the scaling code
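For readers new to the term, the arithmetic behind a rowwise-scaled gemm can be sketched in pure Python as below. This is a reference for the math only, under stated assumptions. It does not call `torch._scaled_mm`, it does not round to float8, and the function name and scale convention (`scale = FP8_MAX / amax`, undone by dividing the output) are illustrative choices. The left operand gets one scale per row, the right operand gets one scale per column, and each output element is rescaled by the product of the two matching scales.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def scaled_mm_rowwise_reference(a, b, fp8_max=FP8_E4M3_MAX):
    """Compute a @ b through rowwise-scaled operands (no fp8 rounding).

    a is M x K with one scale per row.
    b is K x N with one scale per column.
    Assumes every row of a and every column of b has a nonzero amax.
    """
    b_cols = [list(col) for col in zip(*b)]

    a_scales = [fp8_max / max(abs(x) for x in row) for row in a]
    b_scales = [fp8_max / max(abs(x) for x in col) for col in b_cols]

    # Scale each operand into the float8 range (a real kernel casts here).
    a_scaled = [[x * s for x in row] for row, s in zip(a, a_scales)]
    b_scaled = [[x * s for x in col] for col, s in zip(b_cols, b_scales)]

    # out[i][j] = dot(a_scaled[i], b_scaled[j]) / (a_scales[i] * b_scales[j])
    return [
        [
            sum(x * y for x, y in zip(row, col)) / (sa * sb)
            for col, sb in zip(b_scaled, b_scales)
        ]
        for row, sa in zip(a_scaled, a_scales)
    ]


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
out = scaled_mm_rowwise_reference(a, b)
```

Since the scales run along the non-contracted dimensions, they factor out of each dot product. That is why the output can be corrected with a per-row and per-column rescale after the matmul.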