From 5bc5742ecda7a0a1bdd54ca125a23f4f77aad198 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Wed, 22 Jan 2025 14:52:03 -0500
Subject: [PATCH] Remove `TraceableMistralForCausalLM` (#1052)

## Purpose ##
* Remove the changes to `MistralForCausalLM` which were thought to be needed for Pixtral, but are not

## Changes ##
* Remove `TraceableMistralForCausalLM`
* Remove `TraceableMistralForCausalLM`'s use in `LlavaForConditionalGeneration`
* Remove some unneeded imports in `src/llmcompressor/transformers/tracing/llava.py`

## Testing ##
* Ran `examples/multimodal_vision/llava_example.py` to completion
* Ran `examples/multimodal_vision/pixtral_example.py` to completion
* Ran `mixtral_example.py` (see script below) to completion
* `grep -r 'TraceableMistralForCausalLM' src/ examples/ tests/`
* `grep -r 'TraceableLlavaForConditionalGeneration' src/ examples/ tests/`
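For context, a minimal loading sketch (not part of this patch; the checkpoint ID and runtime settings are illustrative assumptions): with the `__init__` override removed, `TraceableLlavaForConditionalGeneration` builds its text model through `AutoModelForCausalLM.from_config`, so a Pixtral-style checkpoint with a Mistral text config loads without any traceable Mistral subclass.

```python3
# Hedged sketch: model ID, dtype, and device_map below are assumptions, not part of this PR.
from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration

model = TraceableLlavaForConditionalGeneration.from_pretrained(
    "mistral-community/pixtral-12b",  # illustrative Pixtral checkpoint with a Mistral text config
    torch_dtype="auto",
    device_map="auto",
)
print(type(model.language_model).__name__)  # expected: the stock MistralForCausalLM
```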
mixtral_example.py
```python3
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# Select model and load it.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
NUM_GPUS = 1

device_map = calculate_offload_device_map(
    MODEL_ID, reserve_for_hessians=True, num_gpus=NUM_GPUS, torch_dtype="auto"
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# * quantize the weights to 4 bit with GPTQ with a group size 128
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
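As an optional follow-up smoke test (not run as part of this PR), the saved W4A16 checkpoint can be loaded for inference, for example with vLLM; the directory name matches `SAVE_DIR` from the script above, and vLLM support for compressed-tensors W4A16 checkpoints is assumed.

```python3
# Hedged sketch: assumes vLLM is installed and can load the compressed checkpoint above.
from vllm import LLM, SamplingParams

llm = LLM("Mixtral-8x7B-Instruct-v0.1-W4A16-G128")  # SAVE_DIR written by the script above
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```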
Signed-off-by: Kyle Sayers
Co-authored-by: Michael Goin
---
 .../transformers/tracing/__init__.py |   2 -
 .../transformers/tracing/llava.py    |  24 +-
 .../transformers/tracing/mistral.py  | 251 ------------------
 3 files changed, 1 insertion(+), 276 deletions(-)
 delete mode 100644 src/llmcompressor/transformers/tracing/mistral.py

diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py
index 7b9c4faa6..fae57dbb1 100644
--- a/src/llmcompressor/transformers/tracing/__init__.py
+++ b/src/llmcompressor/transformers/tracing/__init__.py
@@ -1,7 +1,6 @@
 from .llava import (
     LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
 )
-from .mistral import MistralForCausalLM as TraceableMistralForCausalLM
 from .mllama import (
     MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
 )
@@ -12,6 +11,5 @@
 __all__ = [
     "TraceableLlavaForConditionalGeneration",
     "TraceableMllamaForConditionalGeneration",
-    "TraceableMistralForCausalLM",
     "TraceableQwen2VLForConditionalGeneration",
 ]
diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py
index cce636601..d0160cd11 100644
--- a/src/llmcompressor/transformers/tracing/llava.py
+++ b/src/llmcompressor/transformers/tracing/llava.py
@@ -19,20 +19,14 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
-from transformers import AutoModel, AutoModelForCausalLM, LlavaForConditionalGeneration
+from transformers import LlavaForConditionalGeneration
 from transformers.models.llava.configuration_llava import LlavaConfig
 from transformers.models.llava.modeling_llava import (
     LlavaCausalLMOutputWithPast,
-    LlavaMultiModalProjector,
-    LlavaPreTrainedModel,
     logger,
 )
-from transformers.models.mistral.configuration_mistral import MistralConfig
 from transformers.utils.fx import HFProxy
 
-# TRACING: Reuse traceable subclass
-from .mistral import MistralForCausalLM as TraceableMistralForCausalLM
-
 
 # TRACING: The shape of image_features is known and documented by
 # LlavaForConditionalGeneration.get_image_features
@@ -75,22 +69,6 @@ def maybe_install_metadata_inputs_embeds_masked(
 
 # TRACING: override `__init__` and `forward`
 class LlavaForConditionalGeneration(LlavaForConditionalGeneration):
-    def __init__(self, config: LlavaConfig):
-        super(LlavaPreTrainedModel, self).__init__(config)
-        self.vision_tower = AutoModel.from_config(config.vision_config)
-
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.vocab_size = config.text_config.vocab_size
-
-        # TRACING: Must use TraceableMistralForCausalLM which wraps an untraceable function
-        if isinstance(config.text_config, MistralConfig):
-            self.language_model = TraceableMistralForCausalLM(config.text_config)
-        else:
-            self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
-        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-        self.post_init()
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py
deleted file mode 100644
index 3c9102b23..000000000
--- a/src/llmcompressor/transformers/tracing/mistral.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# flake8: noqa
-# coding=utf-8
-# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# vllm-project: no copyright
-"""PyTorch Mistral model."""
-
-import torch
-from torch import nn
-
-from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.models.mistral.configuration_mistral import MistralConfig
-from transformers.utils import (
-    logging,
-)
-
-# TRACING: imports
-from transformers.models.mistral.modeling_mistral import (
-    MistralPreTrainedModel,
-    MistralModel,
-    MistralForCausalLM,
-    MistralForSequenceClassification,
-    MistralForTokenClassification,
-    MistralForQuestionAnswering,
-)
-
-logger = logging.get_logger(__name__)
-
-
-# TRACING: This function is untracable
-@torch.fx.wrap
-def _prepare_4d_causal_attention_mask_with_cache_position(
-    attention_mask: torch.Tensor,
-    sequence_length: int,
-    target_length: int,
-    dtype: torch.dtype,
-    device: torch.device,
-    cache_position: torch.Tensor,
-    batch_size: int,
-    config: MistralConfig,
-    past_key_values: Cache,
-):
-    if attention_mask is not None and attention_mask.dim() == 4:
-        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-        causal_mask = attention_mask
-    else:
-        min_dtype = torch.finfo(dtype).min
-        causal_mask = torch.full(
-            (sequence_length, target_length),
-            fill_value=min_dtype,
-            dtype=dtype,
-            device=device,
-        )
-        diagonal_attend_mask = torch.arange(
-            target_length, device=device
-        ) > cache_position.reshape(-1, 1)
-        if config.sliding_window is not None:
-            # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
-            # the check is needed to verify is current checkpoint was trained with sliding window or not
-            if (
-                not isinstance(past_key_values, SlidingWindowCache)
-                or sequence_length > target_length
-            ):
-                sliding_attend_mask = torch.arange(target_length, device=device) <= (
-                    cache_position.reshape(-1, 1) - config.sliding_window
-                )
-                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
-        causal_mask *= diagonal_attend_mask
-        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = (
-                causal_mask.clone()
-            )  # copy to contiguous memory for in-place edit
-            if attention_mask.shape[-1] > target_length:
-                attention_mask = attention_mask[:, :target_length]
-            mask_length = attention_mask.shape[-1]
-            padding_mask = (
-                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-            )
-            padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[
-                :, :, :, :mask_length
-            ].masked_fill(padding_mask, min_dtype)
-    return causal_mask
-
-
-# TRACING: must use wrapped _prepare_4d_causal_attention_mask_with_cache_position
-class MistralModel(MistralModel):
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        use_cache: bool,
-        output_attentions: bool,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and use_cache:
-                is_padding_right = (
-                    attention_mask[:, -1].sum().item() != input_tensor.size()[0]
-                )
-                if is_padding_right:
-                    raise ValueError(
-                        "You are attempting to perform batched generation with padding_side='right'"
-                        " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
-                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-                    )
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = (
-            past_key_values.get_seq_length() if past_key_values is not None else 0
-        )
-        using_static_cache = isinstance(past_key_values, StaticCache)
-        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if (
-            self.config._attn_implementation == "sdpa"
-            and not (using_static_cache or using_sliding_window_cache)
-            and not output_attentions
-        ):
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                sliding_window=self.config.sliding_window,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
-        sequence_length = input_tensor.shape[1]
-        # SlidingWindowCache or StaticCache
-        if using_sliding_window_cache or using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        # DynamicCache or no cache
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            device=device,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-            config=self.config,
-            past_key_values=past_key_values,
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            causal_mask = AttentionMaskConverter._unmask_unattended(
-                causal_mask, min_dtype
-            )
-
-        return causal_mask
-
-
-# TRACING: Must use MistralModel with wrapped function
-class MistralForCausalLM(MistralForCausalLM):
-    def __init__(self, config):
-        super(MistralPreTrainedModel, self).__init__(config)
-        # TRACING: Must use MistralModel with wrapped function
-        self.model = MistralModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
-# TRACING: Must use MistralModel with wrapped function
-class MistralForSequenceClassification(MistralForSequenceClassification):
-    def __init__(self, config):
-        super(MistralPreTrainedModel, self).__init__(config)
-        self.num_labels = config.num_labels
-        # TRACING: Must use MistralModel with wrapped function
-        self.model = MistralModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
-# TRACING: Must use MistralModel with wrapped function
-class MistralForTokenClassification(MistralForTokenClassification):
-    def __init__(self, config):
-        super(MistralPreTrainedModel, self).__init__(config)
-        self.num_labels = config.num_labels
-        # TRACING: Must use MistralModel with wrapped function
-        self.model = MistralModel(config)
-        if getattr(config, "classifier_dropout", None) is not None:
-            classifier_dropout = config.classifier_dropout
-        elif getattr(config, "hidden_dropout", None) is not None:
-            classifier_dropout = config.hidden_dropout
-        else:
-            classifier_dropout = 0.1
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.score = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-# TRACING: Must use MistralModel with wrapped function
-class MistralForQuestionAnswering(MistralForQuestionAnswering):
-    def __init__(self, config):
-        super(MistralPreTrainedModel, self).__init__(config)
-        # TRACING: Must use MistralModel with wrapped function
-        self.model = MistralModel(config)
-        self.qa_outputs = nn.Linear(config.hidden_size, 2)
-
-        # Initialize weights and apply final processing
-        self.post_init()
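For context on the tracing pattern the deleted module relied on (and which the remaining tracing modules still use), here is a minimal, self-contained sketch of `torch.fx.wrap`; the names are illustrative and not taken from the repository. Wrapping a module-level function makes symbolic tracing record it as a single leaf call instead of tracing through its data-dependent control flow.

```python3
# Hedged, illustrative sketch of the torch.fx.wrap leaf-function pattern.
import torch


@torch.fx.wrap  # register the function so tracers treat it as a leaf call
def data_dependent_mask(x: torch.Tensor) -> torch.Tensor:
    # Branching on tensor values would break symbolic tracing if traced through
    if x.sum() > 0:
        return torch.ones_like(x)
    return torch.zeros_like(x)


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * data_dependent_mask(x)


gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)  # data_dependent_mask appears as one call_function node
```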