Skip to content

Commit

Permalink
Remove TraceableMistralForCausalLM (#1052)
Browse files Browse the repository at this point in the history
## Purpose ##
* Remove changes to `MistralForCausalLM` which was thought to be needed
for Pixtral, but is not

## Changes ##
* Remove `TraceableMistralForCausalLM`
* Remove `TraceableMistralForCausalLM`'s use in
`LlavaForConditionalGeneration`
* Remove some unneeded imports in
`src/llmcompressor/transformers/tracing/llava.py`

## Testing ##
* Ran `examples/multimodal_vision/llava_example.py` to completion
* Ran `examples/multimodal_vision/pixtral_example.py` to completion
* Ran `mixtral_example.py` to completion
* `grep -r 'TraceableMistralForCausalLM' src/ examples/ tests/`
* `grep -r 'TraceableLlavaForConditionalGeneration' src/ examples/
tests/`

<details><summary>mixtral_example.py</summary>

```python3
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# Select model and load it.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
NUM_GPUS = 1

device_map = calculate_offload_device_map(
    MODEL_ID, reserve_for_hessians=True, num_gpus=NUM_GPUS, torch_dtype="auto"
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
#   * quantize the weights to 4 bit with GPTQ with a group size 128
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
</details>

Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
  • Loading branch information
kylesayrs and mgoin authored Jan 22, 2025
1 parent f46d140 commit 5bc5742
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 276 deletions.
2 changes: 0 additions & 2 deletions src/llmcompressor/transformers/tracing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .llava import (
LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
)
from .mistral import MistralForCausalLM as TraceableMistralForCausalLM
from .mllama import (
MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
)
Expand All @@ -12,6 +11,5 @@
__all__ = [
"TraceableLlavaForConditionalGeneration",
"TraceableMllamaForConditionalGeneration",
"TraceableMistralForCausalLM",
"TraceableQwen2VLForConditionalGeneration",
]
24 changes: 1 addition & 23 deletions src/llmcompressor/transformers/tracing/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,14 @@
from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoModel, AutoModelForCausalLM, LlavaForConditionalGeneration
from transformers import LlavaForConditionalGeneration
from transformers.models.llava.configuration_llava import LlavaConfig
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaMultiModalProjector,
LlavaPreTrainedModel,
logger,
)
from transformers.models.mistral.configuration_mistral import MistralConfig
from transformers.utils.fx import HFProxy

# TRACING: Reuse traceable subclass
from .mistral import MistralForCausalLM as TraceableMistralForCausalLM


# TRACING: The shape of image_features is known and documented by
# LlavaForConditionalGeneration.get_image_features
Expand Down Expand Up @@ -75,22 +69,6 @@ def maybe_install_metadata_inputs_embeds_masked(

# TRACING: override `__init__` and `forward`
class LlavaForConditionalGeneration(LlavaForConditionalGeneration):
def __init__(self, config: LlavaConfig):
super(LlavaPreTrainedModel, self).__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)

self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size

# TRACING: Must use TraceableMistralForCausalLM which wraps an untraceable function
if isinstance(config.text_config, MistralConfig):
self.language_model = TraceableMistralForCausalLM(config.text_config)
else:
self.language_model = AutoModelForCausalLM.from_config(config.text_config)

self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()

def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down
251 changes: 0 additions & 251 deletions src/llmcompressor/transformers/tracing/mistral.py

This file was deleted.

0 comments on commit 5bc5742

Please sign in to comment.