Quant-LLM

This is an FP16 x Floatx mixed-precision matmul kernel optimized for I/O-bound workloads, following FP6-LLM. The actual CUDA kernel is located under csrc/cuda/fp6_llm/. This module provides helper functions to quantize FP32/FP16/BF16 weights to Floatx, as well as integration with the torchao API.
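A Floatx EyMz number has 1 sign bit, y exponent bits and z mantissa bits. The sketch below decodes one such bit pattern into a Python float. It assumes IEEE-754-style conventions (exponent bias 2^(y-1) - 1, subnormals when the exponent field is zero) and no Inf/NaN encodings, so every bit pattern is a finite value. `fpx_to_float` is a hypothetical helper for illustration, not a torchao function.

```python
def fpx_to_float(bits: int, ebits: int, mbits: int) -> float:
    """Hypothetical helper: decode a (1 + ebits + mbits)-bit pattern.

    Assumed layout: [sign | exponent | mantissa], exponent bias
    2**(ebits - 1) - 1, subnormals when the exponent field is 0,
    and no Inf/NaN encodings.
    """
    nbits = 1 + ebits + mbits
    if not 0 <= bits < (1 << nbits):
        raise ValueError(f"bits must fit in {nbits} bits")
    sign = -1.0 if (bits >> (ebits + mbits)) & 1 else 1.0
    exp_field = (bits >> mbits) & ((1 << ebits) - 1)
    man_field = bits & ((1 << mbits) - 1)
    bias = (1 << (ebits - 1)) - 1
    if exp_field == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * (man_field / (1 << mbits)) * 2.0 ** (1 - bias)
    # Normal: implicit leading 1.
    return sign * (1 + man_field / (1 << mbits)) * 2.0 ** (exp_field - bias)


# All non-negative FP4 E2M1 values (bit patterns 0b0000 .. 0b0111).
print([fpx_to_float(b, 2, 1) for b in range(8)])
```

Under these assumptions the E2M1 grid comes out as 0, 0.5, 1, 1.5, 2, 3, 4 and 6, and the largest FP6 E3M2 value (pattern `0b011111`) is 28.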

Usage

```python
from torchao.quantization import (
    quantize_,
    fpx_weight_only,
)

model = ...
model.half()  # not necessary, but recommended to maintain accuracy

# for generic Floatx EyMz where x = 1 + y + z
# fp6 with ebits = 3 and mbits = 2
quantize_(model, fpx_weight_only(3, 2))

# fully compatible with torch.compile()
model.compile(mode="max-autotune", fullgraph=True)
```
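The two arguments to `fpx_weight_only` are the exponent and mantissa bit counts, and the total width is 1 + ebits + mbits. The sketch below computes the largest finite magnitude of each format mentioned in this README. It assumes the same conventions as a standard IEEE-style layout with exponent bias 2^(ebits-1) - 1 and no Inf/NaN encodings (so the all-ones exponent field is an ordinary normal value); `fpx_max` is a hypothetical helper, not part of torchao.

```python
def fpx_max(ebits: int, mbits: int) -> float:
    """Hypothetical helper: largest finite magnitude of an EyMz format.

    Assumes exponent bias 2**(ebits - 1) - 1 and no Inf/NaN encodings.
    """
    bias = (1 << (ebits - 1)) - 1
    max_exp = (1 << ebits) - 1 - bias       # all-ones exponent field is a normal number here
    max_significand = 2.0 - 2.0 ** -mbits   # binary 1.11...1
    return max_significand * 2.0 ** max_exp


FORMATS = [(3, 2), (2, 3), (3, 1), (2, 2), (3, 0), (2, 1), (2, 0)]

for ebits, mbits in FORMATS:
    width = 1 + ebits + mbits
    print(f"FP{width} E{ebits}M{mbits}: max = {fpx_max(ebits, mbits)}")
```

These ranges are small (for example 28 for E3M2 and 7.5 for E2M3 under the stated assumptions), which is why weights are rescaled before being rounded to Floatx. At the extreme, E2M0 can represent only the magnitudes 0, 1, 2 and 4.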

It's also possible to pre-process the weight and call the kernel directly.

```python
import torch
from torchao.dtypes.floatx import to_scaled_tc_floatx
from torchao.ops import quant_llm_linear

fp32_weight = torch.randn(1024, 512).cuda()
ebits, mbits = 3, 2

# pre-process the weight. this will quantize the weight to FP6 and pack it in a special
# layout for tensor cores. refer to the FP6-LLM paper for more details.
fp6_weight, scales = to_scaled_tc_floatx(fp32_weight, ebits, mbits)

fp16_act = torch.randn(1, 512).cuda().half()
outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales)  # shape (1, 1024)
```
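`to_scaled_tc_floatx` both quantizes the weight and packs it into the tensor-core layout. The plain-Python sketch below mimics only the quantization half, which can be handy for sanity-checking kernel outputs on tiny shapes. It assumes per-output-row absmax scaling (scale = row absmax / largest Floatx value), exponent bias 2^(ebits-1) - 1 with no Inf/NaN, and nearest-value rounding with ties going to the smaller magnitude; torchao's exact rounding and FP16 scale storage may differ. `fpx_grid`, `quantize_row` and `reference_linear` are hypothetical helpers, not torchao APIs.

```python
import math


def fpx_grid(ebits: int, mbits: int) -> list[float]:
    """Sorted non-negative values of an EyMz format.

    Assumes exponent bias 2**(ebits - 1) - 1, subnormals, and no Inf/NaN.
    """
    bias = (1 << (ebits - 1)) - 1
    values = set()
    for exp_field in range(1 << ebits):
        for man_field in range(1 << mbits):
            frac = man_field / (1 << mbits)
            if exp_field == 0:
                values.add(frac * 2.0 ** (1 - bias))  # subnormal (or zero)
            else:
                values.add((1.0 + frac) * 2.0 ** (exp_field - bias))  # normal
    return sorted(values)


def quantize_row(row: list[float], ebits: int, mbits: int) -> tuple[list[float], float]:
    """Hypothetical helper: absmax-scale one weight row and round it to the Floatx grid.

    Returns the dequantized row (grid value * scale) and the row's scale.
    """
    grid = fpx_grid(ebits, mbits)
    scale = max(max(abs(x) for x in row), 1e-12) / grid[-1]  # row absmax -> largest Floatx value
    dequant = []
    for x in row:
        target = abs(x) / scale
        # min() keeps the first minimum; grid is ascending, so ties go to the smaller magnitude.
        nearest = min(grid, key=lambda g: abs(g - target))
        dequant.append(math.copysign(nearest, x) * scale)
    return dequant, scale


def reference_linear(act: list[list[float]], weight: list[list[float]],
                     ebits: int, mbits: int) -> list[list[float]]:
    """Hypothetical helper: act @ dequant(weight).T, weight shaped (out_features, in_features)."""
    deq = [quantize_row(row, ebits, mbits)[0] for row in weight]
    return [[sum(a * w for a, w in zip(a_row, w_row)) for w_row in deq] for a_row in act]
```

For example, `reference_linear([[1.0, 1.0, 1.0]], [[28.0, -14.0, 1.0]], 3, 2)` uses a scale of exactly 1 for that row, and every weight already lies on the E3M2 grid, so the result equals the unquantized dot product. Expect small mismatches against the real kernel, which computes in FP16.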

NOTE:

  • Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (rather than BF16) before applying quantization, and to use FP16 activations.
  • Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1.
  • On most hardware, this kernel is faster than an FP16 linear layer for batch sizes from 1 to 128, and slower for batch sizes of 256 or more. See usyd-fsalab/fp6_llm#8 for a detailed discussion. See pytorch#223 for some microbenchmark results.
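The batch-size behavior above is consistent with the kernel targeting memory-bound workloads. The sketch below is a simplified, roofline-style estimate (not the analysis from the linked discussion): it counts the bytes a single linear layer reads and writes and its FLOPs per byte. It ignores scales, caches and kernel overheads, and the 4096 x 4096 layer shape (the hidden size of Llama-2-7B) is just an example.

```python
def linear_traffic_bytes(batch: int, in_features: int, out_features: int,
                         weight_bits: int) -> float:
    """Rough DRAM traffic of one linear layer: weights + FP16 input/output activations.

    Ignores per-row scales, caching, and any kernel-specific overheads.
    """
    weight_bytes = out_features * in_features * weight_bits / 8
    activation_bytes = batch * (in_features + out_features) * 2  # FP16 = 2 bytes/element
    return weight_bytes + activation_bytes


def arithmetic_intensity(batch: int, in_features: int, out_features: int,
                         weight_bits: int) -> float:
    """FLOPs per byte moved; one multiply-add counts as 2 FLOPs."""
    flops = 2 * batch * in_features * out_features
    return flops / linear_traffic_bytes(batch, in_features, out_features, weight_bits)


# Example layer shape (assumed): a 4096 x 4096 projection.
K = N = 4096
for batch in (1, 16, 128, 256):
    ratio = linear_traffic_bytes(batch, K, N, 16) / linear_traffic_bytes(batch, K, N, 6)
    print(f"batch={batch:4d}  FP16/FP6 bytes={ratio:.2f}  "
          f"FP16 intensity={arithmetic_intensity(batch, K, N, 16):.1f} FLOP/byte")
```

At batch size 1 the FP6 weights cut memory traffic by roughly 2.7x. As the batch grows, activation traffic and FLOPs grow while weight traffic stays fixed, so the byte savings shrink and the layer moves toward being compute-bound, where the extra dequantization work is one plausible reason the advantage disappears.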

End-to-End benchmarks

Benchmarks are run on a machine with a single 4070 Ti SUPER GPU using the scripts in _models/llama. tokens/s is measured using generate.py, which generates text in a latency-optimized way (batch size 1). Wikitext perplexity is measured using eval.py, which uses lm_eval. The model used is meta-llama/Llama-2-7b-chat-hf.

Floatx quantization is run with --precision float16. The other configurations use the default precision, bfloat16.

| Quantization | wikitext perplexity | tokens/s |
|---|---|---|
| INT8 | 12.21 | 87.45 |
| INT4-256 (tinygemm) | -- | 157.10 |
| FP6 E3M2 | 12.34 | 106.76 |
| FP6 E2M3 | 12.23 | 106.77 |
| FP5 E3M1 | 12.55 | 122.69 |
| FP5 E2M2 | 12.47 | 122.66 |
| FP4 E3M0 | 14.58 | 145.55 |
| FP4 E2M1 | 15.01 | 146.05 |
| FP3 E2M0 | 74625.18 | 164.49 |

Credits

Credits to the FP6-LLM authors.