
VLM Support via GPTQ Hooks and Data Pipelines #914

Merged — 345 commits merged into main on Jan 8, 2025
Conversation

@kylesayrs (Collaborator) commented Nov 13, 2024

Purpose

  • Enable oneshot quantization of vision-language models

[Images: VLM Banner; Llama 3.2-Vision Graphviz diagram]

Related Issues

Prerequisites

Changes

VLM Support

  • Add multimodal examples in examples/multimodal_vision
  • Modify custom_offload_device_map to support models which are not XForCausalLM
  • Add custom data collators for VLM models in src/llmcompressor/transformers/utils/data_collator.py
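
For illustration, a collator along these lines might be used for batch-size-1 VLM calibration. This is a hypothetical sketch, not the exact code in data_collator.py; field names and tensor handling vary per model processor.

```python
# Hypothetical sketch only; the actual collators in data_collator.py are
# defined per model and may differ in field names and tensor handling.
import torch


def vlm_data_collator(batch):
    # Oneshot calibration typically runs with batch_size=1, so each sample can
    # be converted to tensors directly instead of being padded and stacked.
    assert len(batch) == 1, "this sketch assumes a calibration batch size of 1"
    sample = batch[0]
    # Processor outputs (e.g. input_ids, attention_mask, pixel_values) arrive
    # as nested lists; convert each to a tensor for the model's forward pass.
    return {key: torch.tensor(value) for key, value in sample.items()}
```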

GPTQModifier

  • Implement hooks-based compression in GPTQModifier
    • This replaces layer-compressor, which made many assumptions about model architecture
    • This also enables finer-grained sequential compression such as true_sequential
    • Functions previously implemented in gptq_wrapper.py are now implemented in gptq_quantize.py
  • Implement offload_hessians parameter in GPTQModifier
  • Implement data-pipelines-based calibration in GPTQModifier (see the sketch after this list)
    • First, an attempt will be made to trace the model and run the sequential pipeline
    • If that fails, assumptions will be made about the model architecture and an attempt will be made to run the layer_sequential pipeline
      • This ensures backwards compatibility with any previously supported models
    • If that also fails, then the basic pipeline will be used, which is guaranteed to run but may require using offload_hessians
  • Change hessian instability from a ValueError to a _LinAlgError so it can be ignored by the gptq pipeline fallback mechanism
  • Add support for conv2d as indicated by AutoGPTQ
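
The fallback order described above can be sketched roughly as follows. This is a hedged illustration: run_sequential, run_layer_sequential, and run_basic stand in for the actual pipeline entrypoints, and the real GPTQModifier logic differs in detail.

```python
# Sketch of the calibration fallback logic, assuming placeholder pipeline
# entrypoints; the real GPTQModifier implementation differs in details.
import torch

# Errors that would recur in any pipeline (e.g. hessian instability, OOM), so
# they are re-raised immediately instead of triggering a fallback.
UNFIXABLE_ERRORS = (torch.OutOfMemoryError, torch._C._LinAlgError)


def calibrate(model, dataloader, sequential_targets):
    try:
        # 1. Preferred: trace the model and run the sequential pipeline.
        return run_sequential(model, dataloader, sequential_targets)
    except UNFIXABLE_ERRORS:
        raise
    except Exception:
        pass

    try:
        # 2. Fallback: assume a stack of decoder layers and run the
        #    layer_sequential pipeline (matches previously supported models).
        return run_layer_sequential(model, dataloader, sequential_targets)
    except UNFIXABLE_ERRORS:
        raise
    except Exception:
        pass

    # 3. Last resort: basic forward passes; guaranteed to run, but large
    #    models may additionally need offload_hessians=True to fit in memory.
    return run_basic(model, dataloader)
```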

Data Pipelines

  • Implement the basic skeletons of data pipelines, which are subject to change when data pipelines are pulled out of modifiers
  • Basic Pipeline
    • Performs standard forward passes through the model with provided dataloader
    • Used as fallback, as well as in the future for basic calibration passes
  • Layer Sequential Pipeline
    • Refactor of LayerCompressor as a straightforward data pipeline
    • Uses IntermediatesCache to handle activation offloading
  • Sequential Pipeline
    • Uses graph tracing (torch.fx) to determine where sequential targets (layers) exist in the graph and what their inputs and outputs are
    • Implements BFS algorithm to assign nodes to partitions
      • An ideal implementation consolidates partition indices to assign each node to the latest possible partition, delaying execution. The current implementation addresses the most common case (node.op == get_attr)
    • Each partition (Subgraph) is compiled as an executable python function with the proper inputs and outputs
    • Uses IntermediatesCache to handle activation offloading
  • Implement IntermediatesCache, which automagically handles the offloading and onloading of activations between batches (a stripped-down sketch of the idea follows this list)
    • This class is capable of offloading many non-standard activation types, such as tuples and dataclasses like BaseModelOutputWithPast
    • For convenience, the class also handles masking of padding
    • The class is tested in tests/llmcompressor/pipelines/test_cache.py
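
The offloading idea is easiest to see stripped down. The following is a minimal sketch assuming a recursive helper, not the actual IntermediatesCache API; it shows how tensors inside tuples and dataclasses (e.g. BaseModelOutputWithPast) can be moved to CPU while preserving container structure.

```python
# Minimal sketch of the offloading idea; the real class lives in
# src/llmcompressor/pipelines/cache.py and its API may differ.
import dataclasses

import torch


def offload_value(value, device=torch.device("cpu")):
    # Recursively move tensors to the offload device, preserving container
    # structure for tuples and dataclasses such as BaseModelOutputWithPast.
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, tuple):
        return tuple(offload_value(v, device) for v in value)
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        moved = {
            field.name: offload_value(getattr(value, field.name), device)
            for field in dataclasses.fields(value)
        }
        return dataclasses.replace(value, **moved)
    return value  # non-tensor values (ints, None, ...) are stored as-is


def onload_value(value, device):
    # Symmetric operation used when a batch's intermediates are needed again.
    return offload_value(value, device)
```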

Tracing

  • In order to support sequential quantization of the large variety of different multimodal model architectures, some model definitions have to be altered to support tracing
    • If the calibration dataset is text only, most LLMs and VLMs are traceable without additional work. Multimodal calibration datasets are more likely to require additional work to make the model traceable
    • For many VLMs (but not all), the vision tower is not traceable without significant work. However, this only affects sequential error propagation and (minimal?) increased memory usage, which leaves the door open for future support for quantizing modules in the vision tower
  • Add traceable model definitions for llava, mistral, mllama, and glm (an illustrative example of the kind of change involved follows this list)
  • All copyright licenses allow for alteration and redistribution; the line # vllm-project: no copyright was added in a similar style to text_generation.py
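
As an illustration of the kind of edit a traceable definition may need (a generic example, not a specific change from this PR): data-dependent Python control flow cannot be traced symbolically, and one common remedy is to move it into a helper registered as a torch.fx leaf so the tracer records a single call instead of tracing through the condition.

```python
# Illustrative only; the real edits in the traceable llava/mistral/mllama/glm
# definitions are architecture specific.
import torch


def maybe_drop_empty_mask(attention_mask):
    # Branching on a tensor's runtime value cannot be traced symbolically,
    # because proxy tensors carry no concrete values during tracing.
    if attention_mask is not None and int(attention_mask.sum()) == 0:
        return None
    return attention_mask


# Registering the helper as a "leaf" makes torch.fx insert a call to it into
# the traced graph instead of attempting to trace through the condition.
torch.fx.wrap("maybe_drop_empty_mask")
```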

Future Work / Follow-ups

Winogrande Evaluations

| Model | Dataset | Scheme | Runtime | Winogrande |
|---|---|---|---|---|
| Llama-3-8B | ultrachat | W4A16 | 43m, 2xA4000 | 0.7545 |
| Llama-3-70B | ultrachat | W4A16 | 303m, 1xH100 | 0.8216 |
| Mixtral-8x7B | ultrachat | W4A16 | 317m, 1xA100 | 0.8200 |
| openbmb/MiniCPM3-4B | ultrachat | W4A16 | 63m, 1xA100 | 0.6701 |
| Qwen2-VL-2B-Instruct | ultrachat | W8A8 | 12m, 2xA4000 | 0.6188 |
| Qwen2-VL-2B-Instruct | flickr | W8A8 | 24m, 2xA4000 | 0.6093 |
| Llama-3.2-11B-Vision-Instruct | flickr | W8A8 | 75m, 1xA100 | 0.7837 |
| Pixtral-12B-2409 | flickr | W8A8 | 52m, 1xA100 | 0.7924 |
| llava-1.5-7b-hf | flickr | W8A8 | 15m, 1xH100 | 0.7214 |
| Phi-3-vision-128k-instruct | flickr | W4A16 | 51m, 1xA100 | 0.7151 |

lm_eval --model vllm --model_args pretrained="path/to/model",dtype=auto,max_model_len=4096,tensor_parallel_size=1,gpu_memory_utilization=0.8,enforce_eager=True,add_bos_token=True --tasks winogrande --num_fewshot 5 --batch_size 32
lm_eval --model vllm --model_args pretrained="path/to/model",dtype=bfloat16,max_model_len=4096,tensor_parallel_size=1,gpu_memory_utilization=0.8,enforce_eager=True,add_bos_token=True,max_num_seqs=1 --tasks winogrande --num_fewshot 5 --batch_size 1

MMMU Evaluations

Credit to @shubhra

| Model | Dataset | Scheme | MMMU |
|---|---|---|---|
| Llama-3.2-11B-Vision | N/A | Dense | 0.4144 |
| Llama-3.2-11B-Vision | N/A | FP8-dynamic | 0.4300 |
| Llama-3.2-11B-Vision | flickr | W4A16 | 0.4377 |
| Llama-3.2-11B-Vision | flickr | W4A16-group | 0.4211 |

| Model | Dataset | Scheme | MMMU |
|---|---|---|---|
| Llama-3.2-90B-Vision | N/A | Dense | 0.5388 |
| Llama-3.2-90B-Vision | N/A | FP8-dynamic | 0.5278 |
| Llama-3.2-90B-Vision | flickr | W4A16 | 0.5111 |
| Llama-3.2-90B-Vision | flickr | W4A16-group | 0.5477 |

| Model | Dataset | Scheme | MMMU |
|---|---|---|---|
| Pixtral-12B-2409 | N/A | Dense | 0.5022 |
| Pixtral-12B-2409 | N/A | FP8-dynamic | 0.5322 |
| Pixtral-12B-2409 | flickr | W4A16 | 0.4500 |
| Pixtral-12B-2409 | flickr | W4A16-group | 0.4689 |

Testing

@kylesayrs kylesayrs requested a review from dsikka January 5, 2025 05:57
A collaborator commented on the pipeline fallback code:

```python
input_names = state.data.calib.dataset.column_names
unfixable_errors = (torch.OutOfMemoryError, torch._C._LinAlgError)
try:
    run_sequential(
```

Could we do "Layer Sequential" and "Subgraph Sequential"? "Sequential" being indicative of the data/error propagation, while using "layer" and "subgraph" to differentiate between data structures?


rahul-tuli previously approved these changes Jan 6, 2025

@rahul-tuli (Collaborator) left a comment:

Really like the IntermediatesCache implementation, good job!

@kylesayrs kylesayrs requested review from rahul-tuli and dsikka January 7, 2025 00:22
Collaborator (replying on the same code):

Hm, let me think of other descriptors. I think we just want each of the pipelines beyond the basic pipeline to be a little more verbose in its name.

@mgoin (Member) left a comment:

I think these Traceable model definitions have very opaque changes compared to the reference model definitions. This architecture seems like an intensive blocker to add support for a new model, as it requires a lot of knowledge of tracing limitations. However, I understand the need - I'll look in more detail tomorrow.

@kylesayrs (Collaborator, Author) replied:

@mgoin I think the Tracing Guide will clarify how to make changes to your model so that it is traceable, and why tracing is the best and least invasive solution currently available.

Also note that:

  1. Unlike vllm, custom model definitions are not needed for every model. For the vast majority of text models, custom definitions are not required. Most vision models calibrated with text datasets also do not require custom tracing. Custom definitions are mostly required for vision models calibrated with vision datasets, and even then some models, such as phi3_vision, do not require any changes.
  2. Even if a text model is not traceable, GPTQ falls back to the layer_sequential pipeline, which is equivalent to what is currently on main. Therefore these changes only extend what is possible with llm-compressor now.

@kylesayrs kylesayrs requested review from mgoin and dsikka January 7, 2025 06:35
@dsikka (Collaborator) left a comment:

Great work!

@markurtz markurtz self-requested a review January 8, 2025 16:40
@mgoin (Member) left a comment:

Thanks for responding to comments, let's get to the followup items after this.

@dsikka dsikka merged commit 03e2177 into main Jan 8, 2025
6 of 7 checks passed
@dsikka dsikka deleted the kylesayrs/gptq-partition branch January 8, 2025 22:15