Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove TraceableMistralForCausalLM #1052

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

kylesayrs
Copy link
Collaborator

@kylesayrs kylesayrs commented Jan 10, 2025

Purpose

  • Remove changes to MistralForCausalLM which was thought to be needed for Pixtral, but is not

Changes

  • Remove TraceableMistralForCausalLM
  • Remove TraceableMistralForCausalLM's use in LlavaForConditionalGeneration
  • Remove some unneeded imports in src/llmcompressor/transformers/tracing/llava.py

Testing

  • Ran examples/multimodal_vision/llava_example.py to completion
  • Ran examples/multimodal_vision/pixtral_example.py to completion
  • Ran mixtral_example.py to completion
  • grep -r 'TraceableMistralForCausalLM' src/ examples/ tests/
  • grep -r 'TraceableLlavaForConditionalGeneration' src/ examples/ tests/
mixtral_example.py
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

# Select model and load it.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
NUM_GPUS = 1

device_map = calculate_offload_device_map(
    MODEL_ID, reserve_for_hessians=True, num_gpus=NUM_GPUS, torch_dtype="auto"
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map=device_map,
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
#   * quantize the weights to 4 bit with GPTQ with a group size 128
recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

Signed-off-by: Kyle Sayers <[email protected]>
Copy link

👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

@kylesayrs kylesayrs changed the title Remove TraceableMistralForCausalLM Remove TraceableMistralForCausalLM Jan 10, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant