Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove TraceableMistralForCausalLM #1052

Merged
merged 3 commits into the base branch from the feature branch (branch names not captured in this page extract)
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/llmcompressor/transformers/tracing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from .llava import (
LlavaForConditionalGeneration as TraceableLlavaForConditionalGeneration,
)
from .mistral import MistralForCausalLM as TraceableMistralForCausalLM
from .mllama import (
MllamaForConditionalGeneration as TraceableMllamaForConditionalGeneration,
)
Expand All @@ -12,6 +11,5 @@
__all__ = [
"TraceableLlavaForConditionalGeneration",
"TraceableMllamaForConditionalGeneration",
"TraceableMistralForCausalLM",
"TraceableQwen2VLForConditionalGeneration",
]
24 changes: 1 addition & 23 deletions src/llmcompressor/transformers/tracing/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,14 @@
from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoModel, AutoModelForCausalLM, LlavaForConditionalGeneration
from transformers import LlavaForConditionalGeneration
from transformers.models.llava.configuration_llava import LlavaConfig
from transformers.models.llava.modeling_llava import (
LlavaCausalLMOutputWithPast,
LlavaMultiModalProjector,
LlavaPreTrainedModel,
logger,
)
from transformers.models.mistral.configuration_mistral import MistralConfig
from transformers.utils.fx import HFProxy

# TRACING: Reuse traceable subclass
from .mistral import MistralForCausalLM as TraceableMistralForCausalLM


# TRACING: The shape of image_features is known and documented by
# LlavaForConditionalGeneration.get_image_features
Expand Down Expand Up @@ -75,22 +69,6 @@ def maybe_install_metadata_inputs_embeds_masked(

# TRACING: override `__init__` and `forward`
class LlavaForConditionalGeneration(LlavaForConditionalGeneration):
def __init__(self, config: LlavaConfig):
super(LlavaPreTrainedModel, self).__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)

self.multi_modal_projector = LlavaMultiModalProjector(config)
self.vocab_size = config.text_config.vocab_size

# TRACING: Must use TraceableMistralForCausalLM which wraps an untraceable function
if isinstance(config.text_config, MistralConfig):
self.language_model = TraceableMistralForCausalLM(config.text_config)
else:
self.language_model = AutoModelForCausalLM.from_config(config.text_config)

self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()

def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down
251 changes: 0 additions & 251 deletions src/llmcompressor/transformers/tracing/mistral.py

This file was deleted.

Loading