
[RFC] Sparsity Future Plans #1136

Open

jcaip opened this issue Oct 22, 2024 · 5 comments

jcaip (Contributor) commented Oct 22, 2024

I had a chance to reflect after PTC / CUDA-MODE and wanted to share some thoughts on future plans for sparsity in torchao.

Current State

There are two components to sparsity: accuracy and acceleration.

  • Accuracy consists of finding which parameters to zero-out/remove and how to recover model task performance afterwards.
  • Acceleration consists of removing computation from the model after zeroing out parameters to make it fast.

torchao has primarily focused on the acceleration side, with our 2:4 and block sparse kernels and their composition with quantization. We do offer some prototype features (WeightNormSparsifier, superblock) for accuracy, but these are not our main focus. We felt that these GPU sparsity patterns offered a "Goldilocks" sweet spot in the accuracy/acceleration tradeoff and focused on making them easy to accelerate.
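
As a concrete reference for the acceleration path, here is a minimal sketch using PyTorch core's `to_sparse_semi_structured` subclass. The naive magnitude pruning here is only a stand-in for a real accuracy-preserving sparsifier, and the shapes/dtype are illustrative:

```python
import torch
from torch.sparse import to_sparse_semi_structured

# the 2:4 kernels require an Ampere+ GPU and fp16/bf16 weights
linear = torch.nn.Linear(4096, 4096, bias=False).half().cuda()

# naive magnitude pruning to a 2:4 pattern: zero the 2 smallest of every group of 4
w = linear.weight.detach()
groups = w.view(-1, 4)
drop = groups.abs().topk(2, dim=-1, largest=False).indices
mask = torch.ones_like(groups).scatter_(-1, drop, 0).view_as(w)

# swap in the compressed representation; matmuls now hit the 2:4 sparse kernels
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(w * mask))

x = torch.randn(8192, 4096, dtype=torch.float16, device="cuda")
y = linear(x)
```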

We still believe this to be true, but research in accuracy has evolved from when we first released these APIs. While post-training methods continue to be researched, there’s been a lot of development in 1) distillation and 2) activation sparsity, which we’d like to better support moving forward.

Proposal

Concretely, there are two things we'd like to do:

1) Run sparsity (2:4, block sparse) + distillation experiments for LLaMA 3.

NVIDIA was able to use distillation + pruning to halve the size of LLaMa 3.1, and the new LLaMa 3.2 1B models were trained with distillation.

Instead of shrinking the model with pruning, we can also do so with weight sparsity. Theoretically, the extra granularity afforded by block / semi-structured sparsity would let us push the overall sparsity level of the models higher. For example, NeuralMagic has trained 2:4 sparse LLMs using distillation. We can adapt the torchtune distillation tutorial to get some initial benchmarks.
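
As a rough sketch of what that experiment would optimize (hypothetical temperature and weighting, not the torchtune recipe verbatim), the student keeps a fixed sparsity mask while training against the dense teacher:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # soft targets from the (dense) teacher, temperature-scaled
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # hard-label cross-entropy on the (sparse) student
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1 - alpha) * ce
```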

2) Fast run-time compression routines for 2:4 sparsity

The idea behind activation sparsity papers like CATS / TEAL is to swap the SwiGLU activation for a ReLU-like (zero-based) activation function, which can yield up to 80% sparsity in the activations.
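
To make the idea concrete, here is a hypothetical LLaMA-style MLP with the SiLU gate swapped for ReLU; on random inputs ReLU already produces ~50% exact zeros, and the CATS / TEAL results show trained models can reach much higher sparsity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReLUGatedMLP(nn.Module):
    """LLaMA-style gated MLP with a zero-based (ReLU) gate instead of SiLU."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        gate = F.relu(self.gate_proj(x))  # exact zeros instead of SiLU's small tails
        return self.down_proj(gate * self.up_proj(x))

mlp = ReLUGatedMLP(dim=256, hidden_dim=1024)
x = torch.randn(32, 256)
gate = F.relu(mlp.gate_proj(x))
print("activation sparsity:", (gate == 0).float().mean().item())
```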

In the memory-bound case, we can accelerate the matmul (GEMV) by selectively loading only the rows of W that we need, which is very similar to how we achieve speedups with row-wise structured sparsity. Researchers have been looking into using 2:4 sparsity instead, for speedups in both the memory-bound and compute-bound cases.
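
For the memory-bound GEMV case, the trick is just a gather on W driven by the nonzero activations. A minimal eager sketch (whether you skip rows or columns of W depends on the layout):

```python
import torch

def sparse_act_gemv(W, x):
    # W: (out_features, in_features), x: (in_features,) with many exact zeros.
    # Only the columns of W matching nonzero activations need to be loaded.
    nz = x.nonzero(as_tuple=True)[0]
    return W[:, nz] @ x[nz]

W = torch.randn(4096, 11008)
x = torch.relu(torch.randn(11008))  # ~50% exact zeros from ReLU
torch.testing.assert_close(sparse_act_gemv(W, x), W @ x, rtol=1e-4, atol=1e-4)
```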

For activation sparsity, this means we need run-time sparsification routines, since we can't compress the tensor offline. We've written these routines for 2:4 sparse training, but only as custom CUDA kernels.

We'd like to see if we can make these routines torch.compile-able, so that they can be fused into other calculations, like per-row FP8 scaling. This would also let us explore future optimizations, e.g. fusing a permutation into the model so that we can accelerate more of the computation (using an approach similar to https://arxiv.org/pdf/2408.11551), or adding support for new dtypes like FP8 / INT8.
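
Here is a minimal sketch of the masking half of such a routine, written in plain torch ops so that torch.compile can fuse it with neighboring work; it deliberately skips the cuSPARSELt / CUTLASS metadata packing, which is the hard part of the real kernel:

```python
import torch

def sparsify24(x: torch.Tensor) -> torch.Tensor:
    # keep the 2 largest-magnitude values in every contiguous group of 4
    shape = x.shape
    groups = x.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=-1).indices
    mask = torch.zeros_like(groups).scatter_(-1, keep, 1.0)
    return (groups * mask).reshape(shape)

sparsify24_compiled = torch.compile(sparsify24)

x = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16)
y = sparsify24_compiled(x)  # candidate for fusion with e.g. per-row fp8 scaling
```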

An open question here: would support for more flexible sparsity patterns (for better accuracy), via a permuted / shuffled 2:4 sparse or block sparse matrix, be interesting?

The way I see it, we can potentially get the best of both worlds by shuffling unstructured sparsity into an accelerable GPU sparse format.

Run-time activation sparsity is also important for TTFT in LLM inference. When we are compute-bound in prefill, we can use 2:4 sparsity to speed up inference. However, once we start decoding we are memory-bound and no longer want to use our sparse kernels. A run-time sparsification kernel would let us use the same model for both prefill and decode.
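
Dispatch-wise, that could look roughly like the following sketch; the threshold and the eager `prune_24` helper are placeholders, and the real version would call the fused run-time kernel on the prefill path:

```python
import torch
import torch.nn.functional as F
from torch.sparse import to_sparse_semi_structured

def prune_24(w):
    # naive magnitude-based 2:4 pruning, standing in for the fused run-time kernel
    g = w.reshape(-1, 4)
    keep = g.abs().topk(2, dim=-1).indices
    return (g * torch.zeros_like(g).scatter_(-1, keep, 1.0)).reshape(w.shape)

def linear_forward(x, w_dense, prefill_threshold=256):
    # w_dense is assumed to be an fp16/bf16 CUDA weight
    if x.shape[0] >= prefill_threshold:  # prefill: compute-bound, go 2:4 sparse
        return F.linear(x, to_sparse_semi_structured(prune_24(w_dense)))
    return F.linear(x, w_dense)          # decode: memory-bound, stay dense
```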

I ran some preliminary benchmarks for FP8 + 2:4 sparsity on LLaMA 3 and observed the following TTFT speedups.

| Prefill batch size | Sparse TTFT | Dense TTFT | Speedup |
| --- | --- | --- | --- |
| 8192 | 0.2643 sec | 0.2968 sec | 1.12x |
| 16384 | 0.9132 sec | 0.9664 sec | 1.06x |
| 32768 | 3.2265 sec | 3.4839 sec | 1.08x |

Work Items

Nice to haves:

  • [TODO] Add distillation experiments
  • [INPROGRESS] hipSPARSELt support (?)
  • [TODO] Add fp8 support for runtime-sparsification kernels

Additional Notes

If we create torch.compile-able compression routines, we may also be able to use them to accelerate the memory-bound case by fusing a matmul with the decompression routine, essentially creating a load-as-sparse, compute-as-dense kernel.
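
A speculative sketch of that idea, assuming PyTorch core's `SparseSemiStructuredTensor` and its `to_dense()`; whether torch.compile can actually trace through the subclass and fuse the decompression into the GEMM is exactly what we would need to validate:

```python
import torch
from torch.sparse import to_sparse_semi_structured

@torch.compile
def decompress_mm(x, w_sparse):
    # read only the compressed bytes from HBM, decompress, then matmul;
    # ideally the decompression would be fused into the GEMM prologue
    return x @ w_sparse.to_dense().t()

w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
g = w.reshape(-1, 4)
keep = g.abs().topk(2, dim=-1).indices
w24 = (g * torch.zeros_like(g).scatter_(-1, keep, 1.0)).reshape(w.shape)
w_sparse = to_sparse_semi_structured(w24)

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")  # decode-sized batch
y = decompress_mm(x, w_sparse)
```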

jainapurva added the rfc label Oct 22, 2024
petrex (Collaborator) commented Oct 29, 2024

Nice plan. I am working on hipSPARSELt support and plan to release it in the ROCm 6.3/6.4 time frame.

jcaip (Contributor, Author) commented Dec 26, 2024

cc @alexsamardzic @supriyar @cpuhrsch

2:4 CUTLASS GEMM work options

Previously, we needed to use either cuSPARSELt or the 2:4 CUTLASS kernels in core for accelerated 2:4 sparse matmul, but with the addition of #880 we have the ability to land custom CUTLASS kernels in torchao. There are a couple of interesting things we could do:

H100

  1. Add support for row-wise scaled MM for FP8 with 2:4 sparsity. Currently cuSPARSELt cannot fuse arbitrary scalar multiplies into the matmul, so we would need to fuse the scaling into the surrounding ops. This works fine for tensor-wise FP8 scaling, but not for more granular scaling (row-wise / group-wise); a sketch of the unfused version follows below. This is an H100-focused feature, and something that would help with the 2:4 activation sparsity work @danthe3rd. I believe https://github.com/huyz2023/2by4-pretrain is also trying to do something similar.
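
For reference, here is the unfused version in eager (hypothetical shapes, simple e4m3 absmax scaling); the point of the CUTLASS work would be to move the two row-wise scale multiplies into the sparse GEMM epilogue:

```python
import torch

def rowwise_quant_fp8(t, fp8_max=448.0):
    # per-row absmax scaling into float8_e4m3fn
    scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    return (t / scale).to(torch.float8_e4m3fn), scale

x = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)  # assume already 2:4 pruned

x_fp8, x_scale = rowwise_quant_fp8(x)  # x_scale: (128, 1)
w_fp8, w_scale = rowwise_quant_fp8(w)  # w_scale: (4096, 1)

# unfused reference: accumulate in fp32, then apply both row-wise scale
# vectors as a separate elementwise epilogue
acc = x_fp8.to(torch.float32) @ w_fp8.to(torch.float32).t()
out = (acc * x_scale * w_scale.t()).to(torch.bfloat16)
```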

A100

  1. Enable int8 + 2:4 sparse attention support. Currently we can only apply int8 quantization + 2:4 sparsity to the MLP layers of ViT-H, because we rely on cuSPARSELt to fuse the scalar multiplications into the matmul. We'd like to do the same for the attention blocks, but because those are all fused, we cannot use cuSPARSELt without a performance degradation. We know that the accuracy degradation in the attention blocks is recoverable, as NeuralMagic's 2:4 sparse model sparsified both the attention and MLP weights.

  2. 2:4 sparse GEMM using the MARLIN int4 weight-only (int4wo) packed format. Currently we have int4wo + 2:4 sparse (MARLIN) for decode, but we see higher speedups on prefill workloads with the int8 cuSPARSELt / CUTLASS kernels (Add TTFT benchmarks + update sparsity benchmarks #1140). However, the weights are packed in different formats for MARLIN vs. cuSPARSELt / CUTLASS, so we currently keep two copies of the weights, one for prefill and one for decode. If we can either write an efficient translation between the packed formats or a CUTLASS int8/int4 GEMM that uses the MARLIN format, we can use the same weights and get maximum speedups for both prefill and decode.

alexsamardzic (Collaborator) commented

@jcaip Sorry for the late reply, I'm taking on item 1 from your list next.

jcaip (Contributor, Author) commented Jan 16, 2025

cc @alexsamardzic Great, vllm-project/vllm#10995 may be helpful here as well.

danthe3rd commented

@alexsamardzic here is a paste of some code we've been using (internally) in xFormers for 2:4 + fp8 row-wise scaling: https://pastebin.com/kFTG3bu9
Feel free to start from this if it helps :)
