diff --git a/awq/models/__init__.py b/awq/models/__init__.py index 79ca150e..8c9d95c5 100644 --- a/awq/models/__init__.py +++ b/awq/models/__init__.py @@ -18,6 +18,7 @@ from .gemma2 import Gemma2AWQForCausalLM from .stablelm import StableLmAWQForCausalLM from .starcoder2 import Starcoder2AWQForCausalLM +from .jamba import JambaAWQForCausalLM from .llava_next import LlavaNextAWQForCausalLM from .phi3 import Phi3AWQForCausalLM from .phi3_v import Phi3VAWQForCausalLM diff --git a/awq/models/auto.py b/awq/models/auto.py index 5f6378f7..0afad0eb 100644 --- a/awq/models/auto.py +++ b/awq/models/auto.py @@ -28,6 +28,7 @@ "gemma2": Gemma2AWQForCausalLM, "stablelm": StableLmAWQForCausalLM, "starcoder2": Starcoder2AWQForCausalLM, + "jamba": JambaAWQForCausalLM, "llava_next": LlavaNextAWQForCausalLM, "phi3": Phi3AWQForCausalLM, "phi3_v": Phi3VAWQForCausalLM, diff --git a/awq/models/base.py b/awq/models/base.py index 3a525f82..ba16aca7 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -75,6 +75,7 @@ "gemma2": "AutoModelForCausalLM", "stablelm": "AutoModelForCausalLM", "starcoder2": "AutoModelForCausalLM", + "jamba": "AutoModelForCausalLM", "llava_next": "AutoModelForVision2Seq", "phi3": "AutoModelForCausalLM", "phi3_v": "AutoModelForCausalLM", @@ -507,8 +508,7 @@ def from_quantized( model, checkpoint=model_weights_path, device_map=device_map, - max_memory=max_memory, - no_split_module_classes=[self.layer_type], + no_split_module_classes=[self.layer_type] if isinstance(self.layer_type, str) else self.layer_type, offload_folder=offload_folder, dtype=torch_dtype, ) diff --git a/awq/models/jamba.py b/awq/models/jamba.py new file mode 100644 index 00000000..efb553ca --- /dev/null +++ b/awq/models/jamba.py @@ -0,0 +1,107 @@ +import tqdm +import torch +from typing import List, Tuple, Union +from .base import BaseAWQForCausalLM +from awq.utils.fused_utils import fuse_qkv +from transformers.models.jamba.modeling_jamba import ( + JambaAttentionDecoderLayer as OldJambaAttentionDecoderLayer, + JambaMambaDecoderLayer as OldJambaMambaDecoderLayer, + JambaForCausalLM as OldJambaForCausalLM, +) +from awq.modules.fused.norm import FasterTransformerRMSNorm + + +class JambaAWQForCausalLM(BaseAWQForCausalLM): + layer_type = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"] + max_seq_len_key = "max_position_embeddings" + modules_to_not_convert = ["mamba", "router"] + + @staticmethod + def get_model_layers(model: OldJambaForCausalLM): + return model.model.layers + + @staticmethod + def get_act_for_scaling(module: Union[OldJambaMambaDecoderLayer, OldJambaAttentionDecoderLayer]): + return dict(is_scalable=False) + + @staticmethod + def move_embed(model: OldJambaForCausalLM, device: str): + model.model.embed_tokens = model.model.embed_tokens.to(device) + + @staticmethod + def get_layers_for_scaling(module: Union[OldJambaMambaDecoderLayer, OldJambaAttentionDecoderLayer], input_feat, module_kwargs): + layers = [] + + # attention input + if isinstance(module, OldJambaAttentionDecoderLayer): + layers.append( + dict( + prev_op=module.input_layernorm, + layers=[ + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj, + ], + inp=input_feat["self_attn.q_proj"], + module2inspect=module.self_attn, + kwargs=module_kwargs, + ) + ) + + # attention out + # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: + layers.append( + dict( + prev_op=module.self_attn.v_proj, + layers=[module.self_attn.o_proj], + inp=input_feat["self_attn.o_proj"], + ) + ) + + if hasattr(module.feed_forward, "router"): + # linear in + layers.append( + dict( + prev_op=module.pre_ff_layernorm, + layers=[ + w + for expert in module.feed_forward.experts + for w in [expert.gate_proj, expert.up_proj] + ], + inp=input_feat["feed_forward"], + module2inspect=module.feed_forward, + ) + ) + + # linear out + for i, expert in enumerate(module.feed_forward.experts): + layers.append( + dict( + prev_op=expert.up_proj, + layers=[expert.down_proj], + inp=input_feat[f"feed_forward.experts.{i}.down_proj"], + ) + ) + + else: + # linear 1 + layers.append( + dict( + prev_op=module.pre_ff_layernorm, + layers=[module.feed_forward.gate_proj, module.feed_forward.up_proj], + inp=input_feat["feed_forward.gate_proj"], + module2inspect=module.feed_forward, + ) + ) + + # linear 2 + layers.append( + dict( + prev_op=module.feed_forward.up_proj, + layers=[module.feed_forward.down_proj], + inp=input_feat["feed_forward.down_proj"], + ) + ) + + return layers diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index cd9fb0dd..41b6922a 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -609,7 +609,11 @@ def cache_input_hook(m, x, y, name, feat_dict): **named_linears, "block_sparse_moe": layer.block_sparse_moe, } - + if self.awq_model.model_type == "jamba": + named_linears = { + **named_linears, + "feed_forward": layer.feed_forward, + } if self.awq_model.model_type == "deepseek_v2": named_linears = { **named_linears,