From 847280dbf54cc3b2f60971d1bdc5957087c25c79 Mon Sep 17 00:00:00 2001
From: Hao Jiang
Date: Sun, 21 Apr 2024 21:49:33 +0800
Subject: [PATCH] initial support for jamba

---
 awq/models/__init__.py    |   1 +
 awq/models/auto.py        |   1 +
 awq/models/base.py        |   3 +-
 awq/models/jamba.py       | 107 ++++++++++++++++++++++++++++++++++++++
 awq/quantize/quantizer.py |   6 +++
 5 files changed, 117 insertions(+), 1 deletion(-)
 create mode 100644 awq/models/jamba.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index 2ae3fd55..daaee585 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -17,3 +17,4 @@
 from .gemma import GemmaAWQForCausalLM
 from .stablelm import StableLmAWQForCausalLM
 from .starcoder2 import Starcoder2AWQForCausalLM
+from .jamba import JambaAWQForCausalLM
diff --git a/awq/models/auto.py b/awq/models/auto.py
index 0a236979..c55ad4f6 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -26,6 +26,7 @@
     "gemma": GemmaAWQForCausalLM,
     "stablelm": StableLmAWQForCausalLM,
     "starcoder2": Starcoder2AWQForCausalLM,
+    "jamba": JambaAWQForCausalLM,
 }
 
 
diff --git a/awq/models/base.py b/awq/models/base.py
index ebd45ccc..a668560c 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -70,6 +70,7 @@
     "gemma": "AutoModelForCausalLM",
     "stablelm": "AutoModelForCausalLM",
     "starcoder2": "AutoModelForCausalLM",
+    "jamba": "AutoModelForCausalLM",
 }
 
 
@@ -449,7 +450,7 @@ def from_quantized(
             model,
             checkpoint=model_weights_path,
             device_map=device_map,
-            no_split_module_classes=[self.layer_type],
+            no_split_module_classes=[self.layer_type] if isinstance(self.layer_type, str) else self.layer_type,
             offload_folder=offload_folder,
             dtype=torch_dtype,
         )
diff --git a/awq/models/jamba.py b/awq/models/jamba.py
new file mode 100644
index 00000000..efb553ca
--- /dev/null
+++ b/awq/models/jamba.py
@@ -0,0 +1,107 @@
+import tqdm
+import torch
+from typing import List, Tuple, Union
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from transformers.models.jamba.modeling_jamba import (
+    JambaAttentionDecoderLayer as OldJambaAttentionDecoderLayer,
+    JambaMambaDecoderLayer as OldJambaMambaDecoderLayer,
+    JambaForCausalLM as OldJambaForCausalLM,
+)
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+
+class JambaAWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
+    max_seq_len_key = "max_position_embeddings"
+    modules_to_not_convert = ["mamba", "router"]
+
+    @staticmethod
+    def get_model_layers(model: OldJambaForCausalLM):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: Union[OldJambaMambaDecoderLayer, OldJambaAttentionDecoderLayer]):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model: OldJambaForCausalLM, device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module: Union[OldJambaMambaDecoderLayer, OldJambaAttentionDecoderLayer], input_feat, module_kwargs):
+        layers = []
+
+        # attention input
+        if isinstance(module, OldJambaAttentionDecoderLayer):
+            layers.append(
+                dict(
+                    prev_op=module.input_layernorm,
+                    layers=[
+                        module.self_attn.q_proj,
+                        module.self_attn.k_proj,
+                        module.self_attn.v_proj,
+                    ],
+                    inp=input_feat["self_attn.q_proj"],
+                    module2inspect=module.self_attn,
+                    kwargs=module_kwargs,
+                )
+            )
+
+            # attention out
+            # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+            if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
+                layers.append(
+                    dict(
+                        prev_op=module.self_attn.v_proj,
+                        layers=[module.self_attn.o_proj],
+                        inp=input_feat["self_attn.o_proj"],
+                    )
+                )
+
+        if hasattr(module.feed_forward, "router"):
+            # linear in
+            layers.append(
+                dict(
+                    prev_op=module.pre_ff_layernorm,
+                    layers=[
+                        w
+                        for expert in module.feed_forward.experts
+                        for w in [expert.gate_proj, expert.up_proj]
+                    ],
+                    inp=input_feat["feed_forward"],
+                    module2inspect=module.feed_forward,
+                )
+            )
+
+            # linear out
+            for i, expert in enumerate(module.feed_forward.experts):
+                layers.append(
+                    dict(
+                        prev_op=expert.up_proj,
+                        layers=[expert.down_proj],
+                        inp=input_feat[f"feed_forward.experts.{i}.down_proj"],
+                    )
+                )
+
+        else:
+            # linear 1
+            layers.append(
+                dict(
+                    prev_op=module.pre_ff_layernorm,
+                    layers=[module.feed_forward.gate_proj, module.feed_forward.up_proj],
+                    inp=input_feat["feed_forward.gate_proj"],
+                    module2inspect=module.feed_forward,
+                )
+            )
+
+            # linear 2
+            layers.append(
+                dict(
+                    prev_op=module.feed_forward.up_proj,
+                    layers=[module.feed_forward.down_proj],
+                    inp=input_feat["feed_forward.down_proj"],
+                )
+            )
+
+        return layers
diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py
index 6a4574e6..cd6658b7 100644
--- a/awq/quantize/quantizer.py
+++ b/awq/quantize/quantizer.py
@@ -521,6 +521,12 @@ def cache_input_hook(m, x, y, name, feat_dict):
                 "block_sparse_moe": layer.block_sparse_moe,
             }
 
+        if self.awq_model.model_type == "jamba":
+            named_linears = {
+                **named_linears,
+                "feed_forward": layer.feed_forward,
+            }
+
         for name in named_linears:
             handles.append(
                 named_linears[name].register_forward_hook(
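
A minimal usage sketch for this change, assuming the standard AutoAWQ flow
(AutoAWQForCausalLM.from_pretrained / model.quantize / save_quantized); the model id,
output path, and quant_config values below are illustrative assumptions and are not
taken from this patch:

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "ai21labs/Jamba-v0.1"    # assumed Hugging Face model id
    quant_path = "jamba-awq-w4-g128"      # hypothetical output directory
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # "jamba" now resolves to JambaAWQForCausalLM through the causal LM model map.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # The patch sets modules_to_not_convert = ["mamba", "router"], so the Mamba mixer
    # and MoE router are expected to be left unquantized during this step.
    model.quantize(tokenizer, quant_config=quant_config)

    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)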