From a71bcf472ae75257647fcb9cf0d69ec229794b88 Mon Sep 17 00:00:00 2001
From: Kyle Corbitt
Date: Thu, 18 Apr 2024 15:29:05 -0700
Subject: [PATCH] limit LoRA targets

---
 vllm/model_executor/models/llama.py | 171 +++++++++++++++++-----------
 1 file changed, 103 insertions(+), 68 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 016e3b039d1e8..2a4f206c4076b 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
+
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -29,28 +30,36 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import LoRAConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader)
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.utils import is_hip
 
 
 class LlamaMLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -60,16 +69,22 @@ def __init__(
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            linear_method=linear_method,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
             bias=False,
-            linear_method=linear_method)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           linear_method=linear_method)
+            linear_method=linear_method,
+        )
         if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
@@ -80,7 +95,6 @@ def forward(self, x):
 
 
 class LlamaAttention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -147,11 +161,13 @@ def __init__(
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              sliding_window=sliding_window)
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            sliding_window=sliding_window,
+        )
 
     def forward(
         self,
@@ -163,14 +179,12 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
-                                self.kv_scale)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
         output, _ = self.o_proj(attn_output)
         return output
 
 
 class LlamaDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: LlamaConfig,
@@ -180,18 +194,21 @@ def __init__(
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        max_position_embeddings = getattr(
+            config, "max_position_embeddings", 8192
+        )
         sliding_window = getattr(config, "sliding_window", None)
         # Support abacusai/Smaug-72B-v0.1 with attention_bias
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
-            config, "bias", False)
+            config, "bias", False
+        )
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            num_kv_heads=getattr(config, "num_key_value_heads",
-                                 config.num_attention_heads),
+            num_kv_heads=getattr(
+                config, "num_key_value_heads", config.num_attention_heads
+            ),
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
@@ -205,10 +222,12 @@ def __init__(
             hidden_act=config.hidden_act,
             linear_method=linear_method,
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     def forward(
         self,
@@ -224,7 +243,8 @@ def forward(
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+                hidden_states, residual
+            )
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -234,13 +254,13 @@ def forward(
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+            hidden_states, residual
+        )
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
 
 class LlamaModel(nn.Module):
-
     def __init__(
         self,
         config: LlamaConfig,
@@ -250,8 +270,11 @@ def __init__(
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
         self.vocab_size = config.vocab_size + lora_vocab
         self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
@@ -259,10 +282,12 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, linear_method)
-            for _ in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(config, linear_method)
+                for _ in range(config.num_hidden_layers)
+            ]
+        )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -316,11 +341,8 @@ class LlamaForCausalLM(nn.Module):
         "embed_tokens",
         "lm_head",
     ]
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
@@ -342,12 +364,14 @@ def __init__(
             padding_size=DEFAULT_VOCAB_PADDING_SIZE
             # We need bigger padding if using lora for kernel
             # compatibility
-            if not lora_config else lora_config.lora_vocab_padding_size,
+            if not lora_config
+            else lora_config.lora_vocab_padding_size,
         )
 
         logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+        self.logits_processor = LogitsProcessor(
+            self.unpadded_vocab_size, config.vocab_size, logit_scale
+        )
         self.sampler = Sampler()
 
     def forward(
@@ -357,14 +381,17 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, attn_metadata
+        )
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
-                                       sampling_metadata)
+    def compute_logits(
+        self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
+    ) -> torch.Tensor:
+        logits = self.logits_processor(
+            self.lm_head.weight, hidden_states, sampling_metadata
+        )
         return logits
 
     def sample(
@@ -388,12 +415,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            if ("rotary_emb.cos_cached" in name
-                    or "rotary_emb.sin_cached" in name):
+            if (
+                "rotary_emb.cos_cached" in name
+                or "rotary_emb.sin_cached" in name
+            ):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -409,8 +438,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
                 weight_loader(param, loaded_weight)
 
     # If this function is called, it should always initialize KV cache scale
@@ -420,9 +450,12 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path, tp_rank, tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type):
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
             layer_self_attn = self.model.layers[layer_idx].self_attn
 
             if is_hip():
@@ -434,5 +467,7 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
             if hasattr(layer_self_attn, "kv_scale"):
                 layer_self_attn.kv_scale = scaling_factor
             else:
-                raise RuntimeError("Self attention has no KV cache scaling "
-                                   "factor attribute!")
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling "
+                    "factor attribute!"
+                )
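
The behavioral change in this patch is the hunk at @@ -316,11 +341,8 @@: LlamaForCausalLM.embedding_modules becomes an empty dict and embedding_padding_modules an empty list, so the embedding and LM-head layers are no longer exposed as LoRA embedding targets; the remaining hunks are mechanical reformatting. A minimal sketch, not part of the patch, of confirming the resulting class attributes in a source tree with this change applied:

# Illustrative check only; assumes a vLLM checkout containing the patched llama.py.
from vllm.model_executor.models.llama import LlamaForCausalLM

# After this patch the class advertises no LoRA embedding modules and no
# modules that need extra vocab padding.
print(LlamaForCausalLM.embedding_modules)          # {}  (previously mapped "embed_tokens" and "lm_head")
print(LlamaForCausalLM.embedding_padding_modules)  # []  (previously ["lm_head"])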