Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CUTLASS-based W4A4 #1515

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft

Add CUTLASS-based W4A4 #1515

wants to merge 6 commits into from

Conversation

gau-nernst
Copy link
Collaborator

@gau-nernst gau-nernst commented Jan 7, 2025

Closes #1406

Thanks to #880, we now have a CUTLASS (3.6.0) copy in torchao. Adding W4A4 is pretty straight-forward, similar to how W4A8 is done. This is largely copied from my other repo, so I didn't exactly follow @alexsamardzic's style. Requesting a first round of review.

Note: this is more for doing experiments with W4A4 easier. Personally I don't think it's too useful at the moment, since W4A4 accuracy is probably quite bad.

TODO:

  • Hook up to AQT
  • Benchmark script and get benchmark results for A100 and 4090
  • (Maybe) Do some tuning + heuristics based on problem size

Copy link

pytorch-bot bot commented Jan 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1515

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fe1f0eb with merge base 4996101 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 7, 2025
@gau-nernst gau-nernst added the topic: new feature Use this tag if this PR adds a new feature label Jan 7, 2025
@alexsamardzic
Copy link
Collaborator

CUDA code looks fine, of course there are lots of dots to connect remaining on the Python side.

The difference from #880 is that this is not mixed data types GEMM, but regular GEMM instead. In that regard, this operator here is maybe easier to be made much more generic, to support other integer and maybe even some floating point input data types. I'm at the moment making some minor changes on this PyTorch operator, and would strongly recommend modelling CUDA code in alike way, as it plain looks nice, and then makes extending the kernel to other datatypes much easier, has extensive checks on operands, etc. Moreover, I think it would make sense at this point to discuss having a single CUTLASS-based kernel for GEMMs with both weights and activations scaled, to be put in the single source file, and to handle both same and mixed data types GEMMs, at least for SM 8.x archs - that would provide for minimum code duplication, and easier maintenance in the future.

As far as configurations (tile sizes, number of stages, etc.) concerned, I'd suggest looking here instead in the unit tests, and also comparing performance vs. results reported by CUTLASS profiler for given combination of data types. I believe some sort of tuning configuration on the input shapes is a must in order to achieve a decent performance; but I have to admit that in #880 the tuning is mostly ad-hoc (for comparison, I find this approach more elaborate and meaningful). Thus, I think that coming up with some kind of systematic approach in that regard would be the most beneficial contribution regarding eventual future use of CUTLASS-based kernels in the torchao.

(@drisspg: Your comments welcome here.)

@drisspg
Copy link
Contributor

drisspg commented Jan 7, 2025

One thing on finding optimal params is that @yifuwang was recently working on finding better configs for an AsyncMM. He did some manual elimination of configs that never seemed to be performant and then fit a simple decision Tree on a big sweep over MKN shapes that could be easily modeled in C++. This is similar to what is done in the RowWise scaling. I think a little flow for this would be helpful I can make an issue to track.

No major comments

@gau-nernst
Copy link
Collaborator Author

Thank you for the feedback.

In that regard, this operator here is maybe easier to be made much more generic, to support other integer and maybe even some floating point input data types

Though this is nice on paper, I think Triton is the better alternative for other data types (INT8, FP8...). It's more flexible and the autotuner also saves us some headache. Only because of the lack of INT4 support in Triton, we have to use Cutlass, especially for INT4 Tensor cores. Unless we can show that there are cases Triton cannot reach the perf of Cutlass (in the context of this PR, I'm only thinking about INT8 for SM8x, and additionally FP8 for SM89).

Having said that, I'm ok with following a certain style/structure. Just point me which one it should be, and I will make modifications accordingly.

Returns:
output: result tensor, in row-major layout.
"""
assert A.dtype == B.dtype == torch.int8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add the alignment constraints as well right?

Copy link
Collaborator Author

@gau-nernst gau-nernst Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How should I check for data alignment from Python? I guess in C++, I can check by testing divisibility of the memory address? (or perhaps there is a util function somewhere that I'm not aware of...)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I think there is a restriction that k need to be a multiple of 32 right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or at least 16 packed int4 s

using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
// static int const kStages = 3;
using ElementC = int32_t;
using Gemm = cutlass::gemm::device::Gemm<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if the universal gemm api can be used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will look into it. I wrote this quite some time ago...

@alexsamardzic
Copy link
Collaborator

@gau-nernst

Attached is a minor patch that will change s8s4_linear_cutlass operator to do W4A4. I did it in a quick-a-dirty way, so maybe I made some kind of error (the tests run, but won't pass), but the point is that the main differences are in: checking arguments (I just commented out these), dispatching depending on input data types (I just changed the input data type in this patch instead of having full dispatching) and selecting configuration (for W4A4 I just put the same configuration there that you use in your patch). But the CUTLASS boilerplate code is completely the same, except for using OpMultiplyAddMixedInputUpcast as operator for mixed input data types case. So my point is that I think, for maintenance reasons, we better have single CUTLASS-based kernel for W4A8 and W4A4 (and then all of alike) instead of duplicating lots of code.

The structure of CUTLASS-based kernels is typically always the same (see also rowwise scaled MM in PyTorch, mentioned in my previous comment, as well as my CUTLASS-based mixed data types and 2:4 sparsity kernels in PyTorch): from the bottom up, there is always an operator implementation function that contains checking inputs, and then starting a dispatching chain (where run-time data types etc. are translated to compile-time template arguments), that ends up with a typical CUTLASS-based GEMM kernel (that is boilerplate). Also as mentioned in my previous comment, while rowwise scaled MM is very similar in structure, I like how it looks the most - because of clever use of variable template arguments to decrease the clutter, then because of clear extraction of input checks, and configuration selection into separate functions, etc. So I'd suggest we have your C++ code integrated in the way sketched by attached diff, and then also to made minor changes in the C++ code in a way to make it to look closer to rowwise scaled MM implementation. (Of course, operator name and some other stuff on Python side will have to be changed too.)

diff.txt

@alexsamardzic
Copy link
Collaborator

alexsamardzic commented Jan 8, 2025

As far as performance between various implementations concerned: I'd say in general there are three ways to implement kernels: Triton-based, CUTLASS-based, and custom i.e. from scratch (like Marlin-based kernels). In my experience so far (that was all for Ampere arch), CUTLASS-based kernels are oftentimes somewhat faster than Triton-based kernels, while then for some corner-case input tensor sizes, custom kernels (well, Marlin-based at least) could be significantly faster than CUTLASS-based ones. Furthermore, with Triton there is the least amount of flexibility with upstream changes (they just don't support some input data types, they don't support 2:4 sparsity, etc.), with CUTLASS it's somewhat easier to have changes we may need accepted, while for custom kernels obviously this is not an issue at all. However, Triton kills it when it comes to compilation, in particular regarding fusing GEMM with other kernels, then CUTLASS has some support for compilation but doing fusion is rather cumbersome at the moment, while obviously there is no any kind of compilation support for custom kernels. Then, doing custom kernels would probably lead to lots of code duplication, with CUTLASS this also may be an issue even if to the smaller extent. Etc. - so it's all matter of trade-offs. Still, having in mind auto-tuning and auto-quantization, I belive it still may be good to have as much different kernels in torchao as possible, so I'd expect more CUTLASS-based kernels to be written, besides these W4A8 and W4A4 kernels - and this is the exact reason that, as discussed above, I'd prefer to have as much code shared as possible between these kernels.

@supriyar
Copy link
Contributor

supriyar commented Jan 9, 2025

since W4A4 accuracy is probably quite bad.

Might be interesting to try out QAT with this setting cc @andrewor14

@alexsamardzic
Copy link
Collaborator

The structure of CUTLASS-based kernels is typically always the same (see also rowwise scaled MM in PyTorch, mentioned in my previous comment, as well as my CUTLASS-based mixed data types and 2:4 sparsity kernels in PyTorch): from the bottom up, there is always an operator implementation function that contains checking inputs, and then starting a dispatching chain (where run-time data types etc. are translated to compile-time template arguments), that ends up with a typical CUTLASS-based GEMM kernel (that is boilerplate). Also as mentioned in my previous comment, while rowwise scaled MM is very similar in structure, I like how it looks the most - because of clever use of variable template arguments to decrease the clutter, then because of clear extraction of input checks, and configuration selection into separate functions, etc. So I'd suggest we have your C++ code integrated in the way sketched by attached diff, and then also to made minor changes in the C++ code in a way to make it to look closer to rowwise scaled MM implementation. (Of course, operator name and some other stuff on Python side will have to be changed too.)

I've made these changes to existing CUTLASS-based W4A8 kernel in #1545, so it should be easier now to eventually include W4A4 functionality there.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature Request] W4A4 Quantization Support in torchao
5 participants