[RFC] Sparsity Future Plans #1136
Comments
Nice plan, I am working on hipSPARSELt support and plan to release it in the ROCm 6.3/6.4 time frame.
cc @alexsamardzic @supriyar @cpuhrsch

2:4 CUTLASS GEMM work options

Previously, we needed to use either cuSPARSELt or the 2:4 CUTLASS kernels in core for accelerated 2:4 sparse matmul, but with the addition of #880 we have the ability to land custom CUTLASS kernels into torchao. There are a couple of interesting things we could do here, with separate options for H100 and A100.
@jcaip Sorry for the late reply, I'm taking on item 1 from your list next.
cc @alexsamardzic Great, vllm-project/vllm#10995 may be helpful here as well.
@alexsamardzic here is a paste of some code we've been using (internally) in xFormers for 2:4 + FP8 rowwise scaling: https://pastebin.com/kFTG3bu9
I had a chance to reflect after PTC / CUDA-MODE and wanted to share some thoughts on future plans for sparsity in torchao.
Current State
There are two components to sparsity: accuracy and acceleration.
torchao has primarily focused on the acceleration side, with our 2:4 and block sparse kernels and their composition with quantization. We do offer some prototype features (WeightNormSparsifier, superblock) for accuracy, but these are not our main focus. We felt that these GPU sparsity patterns offered a “Goldilocks” sweet spot in the accuracy/acceleration tradeoff and focused on making them easy to accelerate.
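For context, here is a minimal sketch of the acceleration path we mean, using PyTorch core's semi-structured (2:4) sparse tensor support; the toy weight and shapes are illustrative only, and torchao's own APIs wrap similar kernels:

```python
import torch
from torch.sparse import to_sparse_semi_structured

# Toy weight with a valid 2:4 pattern (2 nonzeros per contiguous group of 4).
# A real weight would first be pruned to 2:4 (e.g. by magnitude) before compression.
W = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()  # shape (128, 128)
x = torch.rand(128, 128).half().cuda()

W_sparse = to_sparse_semi_structured(W)  # offline compression into the 2:4 format

dense_out = torch.mm(W, x)
sparse_out = torch.mm(W_sparse, x)       # dispatches to the accelerated 2:4 kernels
assert torch.allclose(dense_out, sparse_out, rtol=1e-3, atol=1e-3)
```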
We still believe this to be true, but research on the accuracy side has evolved since we first released these APIs. While post-training methods continue to be researched, there’s been a lot of development in 1) distillation and 2) activation sparsity, which we’d like to better support moving forward.
Proposal
Concretely, there are two things we’d like to do:
1) Run sparsity (2:4, block sparse) + distillation experiments for Llama 3
NVIDIA was able to use distillation + pruning to halve the size of Llama 3.1, and the new Llama 3.2 1B models were trained with distillation.
Instead of shrinking the model with pruning, we can also do so with weight sparsity. Theoretically, the extra granularity given to us by block / semi-structured sparsity would allow us to push the overall sparsity level of the models higher. For example, NeuralMagic has trained 2:4 sparse LLMs using distillation. We can adapt the torchtune distillation tutorial to get some initial benchmarks.
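As a rough illustration of what the core of such an experiment could look like (this is not the torchtune recipe; the temperature, loss weighting, and the assumption that the student’s sparsity masks stay fixed during fine-tuning are all placeholder choices), a sparsity-aware distillation step is just a standard KD loss on the logits of a pre-pruned student:

```python
import torch
import torch.nn.functional as F

def distillation_step(student, teacher, batch, optimizer, T=2.0, alpha=0.5):
    """One knowledge-distillation step: hard-label CE plus soft-label KL
    against the teacher. `student` is assumed to already be pruned
    (e.g. 2:4 or block-sparse masks applied to its linear weights)."""
    inputs, labels = batch
    with torch.no_grad():
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)

    hard_loss = F.cross_entropy(student_logits, labels)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)

    loss = alpha * hard_loss + (1 - alpha) * soft_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```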
2) Fast run-time compression routines for 2:4 sparsity
The idea of activation sparsity papers like CATS / TEAL is to swap out the SwiGLU activation for a ReLU-esque (zero-based) activation function, which can yield up to 80% sparsity in the activations.
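For intuition, a toy sketch of the kind of zero-inducing activation these papers use, in the spirit of CATS / TEAL (the threshold below is a made-up placeholder; in practice it is calibrated per layer to hit a target sparsity level):

```python
import torch

def thresholded_activation(x: torch.Tensor, t: float = 0.1) -> torch.Tensor:
    # Zero out small-magnitude activations so downstream matmuls can skip work.
    # `t` is a placeholder; real methods calibrate it per layer / tensor.
    y = torch.nn.functional.silu(x)
    return torch.where(y.abs() < t, torch.zeros_like(y), y)
```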
In the memory-bound case, we can accelerate the matmul (GEMV) by selectively loading only the rows of W that we need, which is very similar to how we achieve speedups with row-wise structured sparsity. Researchers have been looking into using 2:4 sparsity instead for speedups in both the memory-bound and compute-bound case.

For activation sparsity, this means we need run-time sparsification routines, since we can’t compress the tensor offline. We’ve written these routines for 2:4 sparse training, but via custom CUDA kernels.
We’d like to see if we can make these routines torch.compile-able, so that they can be fused into other calculations, like per-row FP8 scaling. This would also let us explore future optimizations, e.g. fusing a permutation into the model so that we can accelerate more of the computation, using an approach similar to https://arxiv.org/pdf/2408.11551, or adding support for new dtypes like FP8 / INT8.
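A purely illustrative version of such a routine, written as mask generation in plain PyTorch ops so that torch.compile can fuse it with neighboring work (the real kernels also emit the packed values and metadata that cuSPARSELt / CUTLASS expect, which this sketch does not):

```python
import torch

def sparsify_24(x: torch.Tensor) -> torch.Tensor:
    """Keep the 2 largest-magnitude values in every contiguous group of 4
    along the last dim and zero the rest. Assumes numel is divisible by 4."""
    orig_shape = x.shape
    groups = x.reshape(-1, 4)
    topk = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, topk, True)
    return (groups * mask).reshape(orig_shape)

compiled_sparsify = torch.compile(sparsify_24)
x = torch.randn(128, 256, device="cuda", dtype=torch.float16)
x_24 = compiled_sparsify(x)  # at most 2 nonzeros per group of 4
```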
An open question here: would support for more flexible sparsity patterns, via a permuted / shuffled 2:4 sparse or block sparse matrix, be interesting?
The way I see it, we can potentially get the best of both worlds by shuffling unstructured sparsity to an accelerable GPU sparse format.
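To make the shuffling idea concrete, the key identity is that permuting the columns of W together with the corresponding rows of x leaves the matmul unchanged, so a fixed permutation fused into the model is free accuracy-wise; a tiny check (here the permutation is random rather than one chosen to make W closer to 2:4 / block sparse):

```python
import torch

W = torch.randn(256, 512)
x = torch.randn(512, 64)

perm = torch.randperm(512)          # in practice: chosen to improve the sparsity structure of W[:, perm]
out_ref = W @ x
out_perm = W[:, perm] @ x[perm, :]  # permute W's columns and x's rows consistently

assert torch.allclose(out_ref, out_perm, atol=1e-4)
```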
Run-time activation sparsity is also important for TTFT in LLM inference. When we are compute-bound in prefill, we can use 2:4 sparsity to speed up inference. However, once we start decoding we are memory-bound and would no longer like to use our sparse kernels. A run-time sparsification kernel would let us use the same model for prefill + decode.
I ran some preliminary benchmarks for FP8 + 2:4 sparsity on Llama 3 and observed speedups in TTFT.
Work Items
Nice to haves:
Additional Notes
If we create torch.compile-able compression routines, we may be able to use them to accelerate the memory-bound case by fusing a matmul with the decompression routine, basically creating a load-as-sparse, compute-as-dense kernel.
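A toy sketch of that idea in plain PyTorch (the compressed layout below is a stand-in with the same 2x storage saving for the values, not the actual cuSPARSELt / CUTLASS metadata format, and whether Inductor fuses the scatter with the matmul today is an open question):

```python
import torch

def compress_24(W: torch.Tensor):
    """Toy compressed 2:4 layout: per group of 4, store the 2 largest-magnitude
    values and their positions within the group."""
    groups = W.reshape(W.shape[0], -1, 4)
    idx = groups.abs().topk(2, dim=-1).indices.sort(dim=-1).values
    vals = groups.gather(-1, idx)
    return vals, idx

def decompress_and_mm(vals, idx, x):
    # Load the compressed (halved) values, scatter them back into a dense
    # weight, and run a dense matmul: "load as sparse, compute as dense".
    groups = torch.zeros(vals.shape[0], vals.shape[1], 4,
                         dtype=vals.dtype, device=vals.device)
    W = groups.scatter_(-1, idx, vals).reshape(vals.shape[0], -1)
    return W @ x

fused = torch.compile(decompress_and_mm)
```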