This is an FP16 x Floatx mixed matmul kernel optimized for I/O-bound workloads, following FP6-LLM. The actual CUDA kernel is located under `csrc/cuda/fp6_llm/`. This module provides helper functions to quantize FP32/FP16/BF16 weights to Floatx, plus integration with the torchao API.
```python
from torchao.quantization import (
    quantize_,
    fpx_weight_only,
)

model = ...
model.half()  # not necessary, but recommended to maintain accuracy

# for a generic Floatx EyMz format, x = 1 + y + z
# FP6 with ebits = 3 and mbits = 2
quantize_(model, fpx_weight_only(3, 2))

# fully compatible with torch.compile()
model.compile(mode="max-autotune", fullgraph=True)
```
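To make the EyMz naming concrete, here is a small pure-Python sketch that decodes every bit pattern of a 1 + ebits + mbits bit format. It assumes an IEEE-754-style layout with exponent bias 2^(ebits-1) - 1, subnormals when the exponent field is zero, and no Inf/NaN encodings. `decode_floatx` and `representable_values` are hypothetical helpers for illustration, not part of torchao.

```python
# Sketch: decode bit patterns of a hypothetical sign/exponent/mantissa FPx
# format with 1 + ebits + mbits bits.
# Assumptions: exponent bias = 2**(ebits - 1) - 1, subnormals when the
# exponent field is 0, and no Inf/NaN encodings (every code is finite).


def decode_floatx(code: int, ebits: int, mbits: int) -> float:
    """Return the real value of an FPx bit pattern (hypothetical helper)."""
    nbits = 1 + ebits + mbits
    assert 0 <= code < 2**nbits
    sign = -1.0 if (code >> (ebits + mbits)) & 1 else 1.0
    exp_field = (code >> mbits) & ((1 << ebits) - 1)
    mantissa = code & ((1 << mbits) - 1)
    bias = 2 ** (ebits - 1) - 1
    if exp_field == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * (mantissa / 2**mbits) * 2.0 ** (1 - bias)
    # Normal: implicit leading 1.
    return sign * (1 + mantissa / 2**mbits) * 2.0 ** (exp_field - bias)


def representable_values(ebits: int, mbits: int) -> list[float]:
    """Sorted distinct non-negative values (codes with the sign bit cleared)."""
    n_magnitude_codes = 2 ** (ebits + mbits)
    return sorted({decode_floatx(c, ebits, mbits) for c in range(n_magnitude_codes)})


for name, e, m in [("FP6 E3M2", 3, 2), ("FP6 E2M3", 2, 3), ("FP5 E2M2", 2, 2), ("FP3 E2M0", 2, 0)]:
    vals = representable_values(e, m)
    print(f"{name}: max={vals[-1]}, {len(vals)} non-negative values")
```

Under these assumptions FP6 E3M2 reaches a largest magnitude of 28.0, while E2M3 stops at 7.5 but has finer steps, and FP3 E2M0 can only represent the magnitudes 0, 1, 2 and 4.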
It's also possible to pre-process the weight and call the kernel directly.
```python
import torch
from torchao.dtypes.floatx import to_scaled_tc_floatx
from torchao.ops import quant_llm_linear

fp32_weight = torch.randn(1024, 512).cuda()
ebits, mbits = 3, 2

# Pre-process the weight: quantize it to FP6 and pack it into a special
# layout for tensor cores. Refer to the FP6-LLM paper for more details.
fp6_weight, scales = to_scaled_tc_floatx(fp32_weight, ebits, mbits)

fp16_act = torch.randn(1, 512).cuda().half()
outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales)  # shape (1, 1024)
```
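The pre-processing step returns a packed weight plus `scales`. The sketch below illustrates only the numeric part, under assumptions that are not taken from torchao: one scale per output row, chosen so that the row's absolute maximum maps to the largest representable value, followed by round-to-nearest onto the FPx grid. The tensor-core packing layout is not shown, and `floatx_grid`, `quantize_row` and `dequantize_row` are hypothetical helpers.

```python
# Sketch of per-output-row absmax scaling + round-to-nearest onto an FPx grid.
# Assumptions (illustrative, not torchao's code): one scale per weight row,
# scale = row_absmax / max_representable, exponent bias 2**(ebits - 1) - 1,
# subnormals, no Inf/NaN. Bit packing for tensor cores is omitted.


def floatx_grid(ebits: int, mbits: int) -> list[float]:
    """All non-negative values of the hypothetical FPx format, sorted."""
    bias = 2 ** (ebits - 1) - 1
    vals = set()
    for e in range(2**ebits):
        for m in range(2**mbits):
            frac = m / 2**mbits
            if e == 0:
                vals.add(frac * 2.0 ** (1 - bias))        # subnormal
            else:
                vals.add((1 + frac) * 2.0 ** (e - bias))  # normal
    return sorted(vals)


def quantize_row(row: list[float], ebits: int, mbits: int) -> tuple[list[float], float]:
    """Fake-quantize one weight row; returns (signed grid values, scale)."""
    grid = floatx_grid(ebits, mbits)
    absmax = max(abs(x) for x in row)
    scale = max(absmax, 1e-12) / grid[-1]  # row absmax maps to the format's max
    q = []
    for x in row:
        target = abs(x) / scale
        mag = min(grid, key=lambda g: abs(g - target))  # nearest grid point
        q.append(mag if x >= 0 else -mag)
    return q, scale


def dequantize_row(q: list[float], scale: float) -> list[float]:
    return [v * scale for v in q]


weight = [[0.5, -1.2, 3.0, 0.01], [-0.2, 0.07, 0.3, -0.9]]
for row in weight:
    q, s = quantize_row(row, ebits=3, mbits=2)
    print(f"scale={s:.4f} q={q} dequant={dequantize_row(q, s)}")
```

A per-row scale keeps one outlier row from squeezing the dynamic range of every other row, at the cost of storing one extra FP16 value per output channel.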
NOTE:
- Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization and use FP16 for activations.
- Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official FP6-LLM repo. We additionally enable support for FP6 E2M3 and FP5 E3M1.
- On most hardware, this kernel is faster than FP16 linear for batch sizes from 1 to 128, and slower for batch sizes of 256 or larger. See usyd-fsalab/fp6_llm#8 for a detailed discussion. See pytorch#223 for some microbenchmark results.
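One way to see why the kernel wins at small batch sizes is a back-of-envelope arithmetic-intensity estimate (FLOPs per byte of DRAM traffic) for `y = x @ W.T`. The sketch assumes each operand crosses DRAM exactly once and ignores scales and dequantization cost; the layer shape is hypothetical.

```python
# Back-of-envelope arithmetic intensity for y = x @ W.T with
# x: (batch, K) FP16 and W: (N, K) stored at `weight_bits` bits per element.
# Assumptions: every operand is read/written exactly once; scales and
# dequantization cost are ignored. Illustrative only.


def arithmetic_intensity(batch: int, n: int, k: int, weight_bits: int) -> float:
    flops = 2 * batch * n * k                # one multiply + one add per MAC
    weight_bytes = n * k * weight_bits / 8
    act_bytes = batch * k * 2                # FP16 input activations
    out_bytes = batch * n * 2                # FP16 output
    return flops / (weight_bytes + act_bytes + out_bytes)


N, K = 4096, 4096  # hypothetical layer shape
for batch in (1, 16, 128, 256):
    fp16 = arithmetic_intensity(batch, N, K, 16)
    fp6 = arithmetic_intensity(batch, N, K, 6)
    print(f"batch={batch:4d}  FP16: {fp16:7.1f} FLOP/B  FP6: {fp6:7.1f} FLOP/B")
```

At batch size 1 the weight dominates the traffic, so FP6 weights raise the intensity by roughly 16/6 ≈ 2.7× in this model. As the batch grows, the intensity climbs toward the compute-bound regime, where fewer weight bytes stop helping and dequantization overhead can make the kernel slower than FP16.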
Benchmarks are run on a machine with a single RTX 4070 Ti SUPER GPU using the scripts in `_models/llama`. Throughput (tokens/s) is measured with `generate.py`, which generates text in a latency-optimized way (batch size 1). WikiText perplexity is measured with `eval.py`, which uses `lm_eval`. The model is `meta-llama/Llama-2-7b-chat-hf`.
Floatx quantization is run with `--precision float16`. The rest use the default precision of `bfloat16`.
Quantization | wikitext perplexity | tokens/s |
---|---|---|
INT8 | 12.21 | 87.45 |
INT4-256 (tinygemm) | -- | 157.10 |
FP6 E3M2 | 12.34 | 106.76 |
FP6 E2M3 | 12.23 | 106.77 |
FP5 E3M1 | 12.55 | 122.69 |
FP5 E2M2 | 12.47 | 122.66 |
FP4 E3M0 | 14.58 | 145.55 |
FP4 E2M1 | 15.01 | 146.05 |
FP3 E2M0 | 74625.18 | 164.49 |
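The FP rows of the table above can be compared against a naive expectation: if batch-size-1 decoding were limited purely by reading weights, tokens/s would scale with 1/bits. This sketch only divides the measured numbers; picking one format per bit width (E3M2, E3M1, E3M0, E2M0) is an arbitrary choice, and the bandwidth-only model is an assumption.

```python
# Compare measured batch-size-1 throughput (tokens/s, copied from the table)
# with the naive expectation that throughput scales inversely with weight
# bit width. Assumption: decoding is dominated by reading weights; activations,
# KV cache, dequantization and non-quantized layers are ignored.

measured = {  # bits -> tokens/s
    6: 106.76,  # FP6 E3M2
    5: 122.69,  # FP5 E3M1
    4: 145.55,  # FP4 E3M0
    3: 164.49,  # FP3 E2M0
}

baseline_bits = 6
for bits, tps in measured.items():
    ideal = baseline_bits / bits
    observed = tps / measured[baseline_bits]
    print(f"FP{bits}: ideal x{ideal:.2f}, observed x{observed:.2f} vs FP6")
```

The observed gains (about 1.15×, 1.36× and 1.54× over FP6 for 5, 4 and 3 bits) fall below the ideal 1.2×, 1.5× and 2×, which is what one would expect if part of each decoding step is not bound by weight bandwidth.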
Credits to the FP6-LLM authors.