testing HQQ [not for land] #155

Open: wants to merge 2 commits into gh/HDCharles/9/base
generate.py (10 additions, 1 deletion)
@@ -224,7 +224,16 @@ def _load_model(checkpoint_path, device, precision, use_tp):
        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

-    if "int4" in str(checkpoint_path):
+    if "int4-hqq" in str(checkpoint_path):
+        print("Using int4 weight-only HQQ quantization.")
+        from quantize import WeightOnlyInt4HqqQuantHandler
+        path_comps = checkpoint_path.name.split(".")
+        assert path_comps[-3].startswith("g")
+        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
+        groupsize = int(path_comps[-3][1:])
+        quantizer = WeightOnlyInt4HqqQuantHandler(model, groupsize=groupsize)
+        model = quantizer._convert_for_runtime()
+    elif "int4" in str(checkpoint_path):
        print("Using int4 weight-only quantization!")
        path_comps = checkpoint_path.name.split(".")
        assert path_comps[-3].startswith("g")
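For reference, the new branch in _load_model keys off the checkpoint filename that quantize.py writes: the third-from-last dot component encodes the groupsize and the second-from-last encodes the device the weights were packed for. A minimal parsing sketch, using a hypothetical checkpoint name that follows that convention:

from pathlib import Path

# Hypothetical checkpoint name following quantize.py's naming convention:
# <base>_int4-hqq.g<groupsize>.<device>.pth
ckpt = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model_int4-hqq.g32.cuda.pth")

path_comps = ckpt.name.split(".")      # ['model_int4-hqq', 'g32', 'cuda', 'pth']
assert path_comps[-3].startswith("g")  # group-size component, e.g. 'g32'
groupsize = int(path_comps[-3][1:])    # -> 32
device = path_comps[-2]                # -> 'cuda'; must match the runtime device
print(groupsize, device)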
quantize.py (40 additions, 1 deletion)
@@ -519,6 +519,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )

+# TODO a hacky placeholder class
+class WeightOnlyInt4HqqQuantHandler:
+    def __init__(self, mod, groupsize):
+        self.mod = mod
+        self.groupsize = groupsize
+
+    def _create_quantized_state_dict(self):
+        from hqq.core.quantize import Quantizer  # TODO maybe torchao
+
+        for m in self.mod.modules():
+            for name, child in m.named_children():
+                if isinstance(child, torch.nn.Linear):
+                    child.weight = torch.nn.Parameter(
+                        Quantizer.dequantize(
+                            *Quantizer.quantize(
+                                child.weight,
+                                nbits=4,
+                                group_size=self.groupsize,
+                                axis=1,
+                            )
+                        )
+                    )
+
+        return WeightOnlyInt4QuantHandler(self.mod, self.groupsize).create_quantized_state_dict()
+
+    def _convert_for_runtime(self):
+        return WeightOnlyInt4GPTQQuantHandler(self.mod, self.groupsize).convert_for_runtime(use_cuda=True)


def quantize(
    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
@@ -592,6 +619,18 @@ def quantize(
        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
+
+    elif mode == 'int4-hqq':
+        print("Quantizing model weights for int4 using HQQ")
+        quant_handler = WeightOnlyInt4HqqQuantHandler(model, groupsize)
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+
+        quantized_state_dict = quant_handler._create_quantized_state_dict()
+        dir_name = checkpoint_path.parent
+        base_name = checkpoint_path.name
+        new_base_name = base_name.replace('.pth', f"{label}int4-hqq.g{groupsize}.{device}.pth")
    else:
        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq, int4-hqq]")

@@ -606,7 +645,7 @@ def quantize(
    import argparse
    parser = argparse.ArgumentParser(description='Quantize a model.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
-    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
+    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq', 'int4-hqq'], help='type of quantization to perform')
    parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
    parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
    parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
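For context, the handler above uses HQQ only to choose the 4-bit quantization parameters: it quantizes each nn.Linear weight with hqq's Quantizer and immediately dequantizes it back to floating point, then hands the fake-quantized weights to the existing int4 packing and GPTQ-style runtime path. A rough standalone sketch of that round trip, assuming the hqq package is installed and that Quantizer.quantize returns a (packed weight, metadata) pair as the call pattern in the diff implies; the tensor shape and group_size are illustrative only:

import torch
from hqq.core.quantize import Quantizer  # same import as in the handler above

# Fake-quantize a single weight matrix: quantize, then dequantize right back.
w = torch.randn(256, 256)
w_q, meta = Quantizer.quantize(w, nbits=4, group_size=32, axis=1)
w_fq = Quantizer.dequantize(w_q, meta)  # floating-point tensor carrying int4 error
print("mean abs quantization error:", (w - w_fq).abs().mean().item())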
run.sh (new file, 27 additions)
@@ -0,0 +1,27 @@
export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf

# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile # working
# echo "base"
# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext --limit 5


python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-hqq
# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --compile
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-hqq.g32.cuda.pth --tasks wikitext

python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --compile
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# broken

# export MODEL_REPO=meta-llama/Llama-2-70b-chat-hf
# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
# python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext --limit 5
# ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth

# python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 5