diff --git a/src/llmcompressor/transformers/tracing/__init__.py b/src/llmcompressor/transformers/tracing/__init__.py
index 7b9c4faa6..fae57dbb1 100644
--- a/src/llmcompressor/transformers/tracing/__init__.py
+++ b/src/llmcompressor/transformers/tracing/__init__.py
@@ -1,7 +1,6 @@
 from .llava import (
     LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
 )
-from .mistral import MistralForCausalLM as TraceableMistralForCausalLM
 from .mllama import (
     MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
 )
@@ -12,6 +11,5 @@
 __all__ = [
     "TraceableLlavaForConditionalGeneration",
     "TraceableMllamaForConditionalGeneration",
-    "TraceableMistralForCausalLM",
     "TraceableQwen2VLForConditionalGeneration",
 ]
diff --git a/src/llmcompressor/transformers/tracing/llava.py b/src/llmcompressor/transformers/tracing/llava.py
index cce636601..d0160cd11 100644
--- a/src/llmcompressor/transformers/tracing/llava.py
+++ b/src/llmcompressor/transformers/tracing/llava.py
@@ -19,20 +19,14 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
-from transformers import AutoModel, AutoModelForCausalLM, LlavaForConditionalGeneration
+from transformers import LlavaForConditionalGeneration
 from transformers.models.llava.configuration_llava import LlavaConfig
 from transformers.models.llava.modeling_llava import (
     LlavaCausalLMOutputWithPast,
-    LlavaMultiModalProjector,
-    LlavaPreTrainedModel,
     logger,
 )
-from transformers.models.mistral.configuration_mistral import MistralConfig
 from transformers.utils.fx import HFProxy
 
-# TRACING: Reuse traceable subclass
-from .mistral import MistralForCausalLM as TraceableMistralForCausalLM
-
 
 # TRACING: The shape of image_features is known and documented by
 # LlavaForConditionalGeneration.get_image_features
@@ -75,22 +69,6 @@ def maybe_install_metadata_inputs_embeds_masked(
 
 # TRACING: override `__init__` and `forward`
 class LlavaForConditionalGeneration(LlavaForConditionalGeneration):
-    def __init__(self, config: LlavaConfig):
-        super(LlavaPreTrainedModel, self).__init__(config)
-        self.vision_tower = AutoModel.from_config(config.vision_config)
-
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.vocab_size = config.text_config.vocab_size
-
-        # TRACING: Must use TraceableMistralForCausalLM which wraps an untraceable function
-        if isinstance(config.text_config, MistralConfig):
-            self.language_model = TraceableMistralForCausalLM(config.text_config)
-        else:
-            self.language_model = AutoModelForCausalLM.from_config(config.text_config)
-
-        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
-        self.post_init()
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
diff --git a/src/llmcompressor/transformers/tracing/mistral.py b/src/llmcompressor/transformers/tracing/mistral.py
deleted file mode 100644
index 3c9102b23..000000000
--- a/src/llmcompressor/transformers/tracing/mistral.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# flake8: noqa
-# coding=utf-8
-# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# vllm-project: no copyright
-"""PyTorch Mistral model."""
-
-import torch
-from torch import nn
-
-from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.models.mistral.configuration_mistral import MistralConfig
-from transformers.utils import (
-    logging,
-)
-
-# TRACING: imports
-from transformers.models.mistral.modeling_mistral import (
-    MistralPreTrainedModel,
-    MistralModel,
-    MistralForCausalLM,
-    MistralForSequenceClassification,
-    MistralForTokenClassification,
-    MistralForQuestionAnswering,
-)
-
-logger = logging.get_logger(__name__)
-
-
-# TRACING: This function is untracable
-@torch.fx.wrap
-def _prepare_4d_causal_attention_mask_with_cache_position(
-    attention_mask: torch.Tensor,
-    sequence_length: int,
-    target_length: int,
-    dtype: torch.dtype,
-    device: torch.device,
-    cache_position: torch.Tensor,
-    batch_size: int,
-    config: MistralConfig,
-    past_key_values: Cache,
-):
-    if attention_mask is not None and attention_mask.dim() == 4:
-        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-        causal_mask = attention_mask
-    else:
-        min_dtype = torch.finfo(dtype).min
-        causal_mask = torch.full(
-            (sequence_length, target_length),
-            fill_value=min_dtype,
-            dtype=dtype,
-            device=device,
-        )
-        diagonal_attend_mask = torch.arange(
-            target_length, device=device
-        ) > cache_position.reshape(-1, 1)
-        if config.sliding_window is not None:
-            # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
-            # the check is needed to verify is current checkpoint was trained with sliding window or not
-            if (
-                not isinstance(past_key_values, SlidingWindowCache)
-                or sequence_length > target_length
-            ):
-                sliding_attend_mask = torch.arange(target_length, device=device) <= (
-                    cache_position.reshape(-1, 1) - config.sliding_window
-                )
-                diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
-        causal_mask *= diagonal_attend_mask
-        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-        if attention_mask is not None:
-            causal_mask = (
-                causal_mask.clone()
-            )  # copy to contiguous memory for in-place edit
-            if attention_mask.shape[-1] > target_length:
-                attention_mask = attention_mask[:, :target_length]
-            mask_length = attention_mask.shape[-1]
-            padding_mask = (
-                causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-            )
-            padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[
-                :, :, :, :mask_length
-            ].masked_fill(padding_mask, min_dtype)
-    return causal_mask
-
-
-# TRACING: must use wrapped _prepare_4d_causal_attention_mask_with_cache_position
-class MistralModel(MistralModel):
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        use_cache: bool,
-        output_attentions: bool,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and use_cache:
-                is_padding_right = (
-                    attention_mask[:, -1].sum().item() != input_tensor.size()[0]
-                )
-                if is_padding_right:
-                    raise ValueError(
-                        "You are attempting to perform batched generation with padding_side='right'"
-                        " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
-                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-                    )
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = (
-            past_key_values.get_seq_length() if past_key_values is not None else 0
-        )
-        using_static_cache = isinstance(past_key_values, StaticCache)
-        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if (
-            self.config._attn_implementation == "sdpa"
-            and not (using_static_cache or using_sliding_window_cache)
-            and not output_attentions
-        ):
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                sliding_window=self.config.sliding_window,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
-        sequence_length = input_tensor.shape[1]
-        # SlidingWindowCache or StaticCache
-        if using_sliding_window_cache or using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        # DynamicCache or no cache
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            device=device,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-            config=self.config,
-            past_key_values=past_key_values,
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            causal_mask = AttentionMaskConverter._unmask_unattended(
-                causal_mask, min_dtype
-            )
-
-        return causal_mask
-
-
-# TRACING: Must use MistralModel with wrapped function
-class MistralForCausalLM(MistralForCausalLM):
-    def __init__(self, config):
-        super(MistralPreTrainedModel, self).__init__(config)
-        # TRACING: Must use MistralModel with wrapped function
-        self.model = MistralModel(config)
-        self.vocab_size = config.vocab_size
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
-# TRACING: Must use MistralModel with wrapped function
-class MistralForSequenceClassification(MistralForSequenceClassification):
-    def __init__(self, config):
-        super(MistralPreTrainedModel, self).__init__(config)
-        self.num_labels = config.num_labels
-        # TRACING: Must use MistralModel with wrapped function
-        self.model = MistralModel(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
-# TRACING: Must use MistralModel with wrapped function
-class MistralForTokenClassification(MistralForTokenClassification):
-    def __init__(self, config):
-        super(MistralPreTrainedModel, self).__init__(config)
-        self.num_labels = config.num_labels
-        # TRACING: Must use MistralModel with wrapped function
-        self.model = MistralModel(config)
-        if getattr(config, "classifier_dropout", None) is not None:
-            classifier_dropout = config.classifier_dropout
-        elif getattr(config, "hidden_dropout", None) is not None:
-            classifier_dropout = config.hidden_dropout
-        else:
-            classifier_dropout = 0.1
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.score = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-# TRACING: Must use MistralModel with wrapped function
-class MistralForQuestionAnswering(MistralForQuestionAnswering):
-    def __init__(self, config):
-        super(MistralPreTrainedModel, self).__init__(config)
-        # TRACING: Must use MistralModel with wrapped function
-        self.model = MistralModel(config)
-        self.qa_outputs = nn.Linear(config.hidden_size, 2)
-
-        # Initialize weights and apply final processing
-        self.post_init()
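Context for the deletion above (illustrative, not part of the diff): the removed mistral.py existed only to register _prepare_4d_causal_attention_mask_with_cache_position as an FX leaf via torch.fx.wrap, so that its data-dependent mask construction would not be symbolically traced. Below is a minimal sketch of that wrapping pattern under assumed, hypothetical names (build_causal_mask, TinyBlock); it is not code from this repository.

import torch


# Registering the helper with torch.fx.wrap keeps its data-dependent logic out of
# the traced graph; the call appears as a single opaque call_function node instead.
@torch.fx.wrap
def build_causal_mask(sequence_length: int, dtype: torch.dtype) -> torch.Tensor:
    # torch.full with a runtime-dependent shape would otherwise break symbolic tracing
    mask = torch.full(
        (sequence_length, sequence_length), torch.finfo(dtype).min, dtype=dtype
    )
    return mask.triu(diagonal=1)


class TinyBlock(torch.nn.Module):
    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        # scores: (batch, seq_len, seq_len) attention scores
        mask = build_causal_mask(scores.shape[-1], scores.dtype)
        return scores + mask


traced = torch.fx.symbolic_trace(TinyBlock())
print(traced.graph)  # build_causal_mask shows up as one leaf call_function node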