add freeze_LLM_only option for mllama finetuning #791

Merged
1 change: 1 addition & 0 deletions recipes/quickstart/finetuning/README.md
@@ -54,6 +54,7 @@ It lets us specify the training settings for everything from `model_name` to `da
output_dir: str = "PATH/to/save/PEFT/model"
freeze_layers: bool = False
num_freeze_layers: int = 1
freeze_LLM_only: bool = False # Freeze the self-attention layers of the language_model; the vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
quantization: str = None
one_gpu: bool = False
save_model: bool = True
6 changes: 6 additions & 0 deletions recipes/quickstart/finetuning/finetune_vision_model.md
@@ -18,6 +18,12 @@ For **LoRA finetuning with FSDP**, we can run the following code:
```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --use_peft --peft_method lora
```

For **finetuning with the LLM frozen (`freeze_LLM_only`) using FSDP**, we can run the following code:

```bash
torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --freeze_LLM_only True
```
**Note**: `--batching_strategy padding` is required, as the vision model does not work with the `packing` method.

For more details about the finetuning configurations, please read the [finetuning readme](./README.md).
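
For reference, a minimal sketch (not part of the recipe itself) of how the resulting freeze could be verified, assuming the `transformers` Mllama layout used here (`model.language_model.model`) and the `freeze_LLM_only` helper added in this PR:

```python
# Count trainable parameters per top-level module after applying the freeze.
from collections import defaultdict

from transformers import MllamaForConditionalGeneration

from llama_recipes.utils.train_utils import freeze_LLM_only

model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct"
)
freeze_LLM_only(model)

trainable = defaultdict(int)
for name, param in model.named_parameters():
    if param.requires_grad:
        trainable[name.split(".")[0]] += param.numel()

# Expected: vision_model and multi_modal_projector are fully trainable, while
# language_model only reports trainable parameters from its cross-attention layers.
for module, numel in sorted(trainable.items()):
    print(f"{module}: {numel / 1e6:.1f}M trainable params")
```

This mirrors the per-module report that `print_frozen_model_status` prints during finetuning.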
1 change: 1 addition & 0 deletions src/llama_recipes/configs/training.py
@@ -35,6 +35,7 @@ class train_config:
output_dir: str = "PATH/to/save/PEFT/model"
freeze_layers: bool = False
num_freeze_layers: int = 1
freeze_LLM_only: bool = False # Freeze the self-attention layers of the language_model; the vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
quantization: str = None
one_gpu: bool = False
save_model: bool = True
21 changes: 18 additions & 3 deletions src/llama_recipes/finetuning.py
@@ -38,8 +38,10 @@
from llama_recipes.utils.train_utils import (
clear_gpu_cache,
freeze_transformer_layers,
freeze_LLM_only,
get_policies,
print_model_size,
print_frozen_model_status,
setup,
setup_environ_flags,
train,
@@ -194,7 +196,7 @@ def main(**kwargs):
model.resize_token_embeddings(len(tokenizer))

print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

# Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
if (
train_config.enable_fsdp
@@ -235,7 +237,14 @@

if not train_config.use_peft and train_config.freeze_layers:
freeze_transformer_layers(model, train_config.num_freeze_layers)

# print model size and per-module frozen status after freezing the selected layers
print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

# freeze_LLM_only is only supported for mllama: freeze the language model's self-attention
# layers and keep the vision model, multi_modal_projector and cross-attention trainable
if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
freeze_LLM_only(model)
# print model size and per-module frozen status after freezing the LLM
print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
# Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
if is_vision:
@@ -255,6 +264,11 @@
device_id = torch.xpu.current_device()
elif torch.cuda.is_available():
device_id = torch.cuda.current_device()

# a partially frozen model leaves FSDP flat-parameters with mixed requires_grad,
# which FSDP only supports when use_orig_params=True
if train_config.freeze_LLM_only:
use_orig_params = True
else:
use_orig_params = False
model = FSDP(
model,
auto_wrap_policy=(
@@ -282,6 +296,7 @@
if train_config.low_cpu_fsdp and rank != 0
else None
),
use_orig_params=use_orig_params,
)
if fsdp_config.fsdp_activation_checkpointing:
model.enable_input_require_grads()
@@ -297,7 +312,7 @@ def main(**kwargs):
dataset_processer = processor
else:
dataset_processer = tokenizer

# Load and preprocess the dataset for training and validation

dataset_train = get_preprocessed_dataset(
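
The `use_orig_params` switch above is what allows FSDP to shard a partially frozen model: FSDP flattens each wrapped module's parameters into a single `FlatParameter`, which must have uniform `requires_grad` unless the original parameters remain visible. A minimal standalone sketch of that constraint (hypothetical script, assuming a CUDA device and a `torchrun --nproc_per_node 1` launch):

```python
# Mixing frozen and trainable parameters under one FSDP wrap needs use_orig_params=True.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
for p in model[0].parameters():  # freeze one half, keep the other trainable
    p.requires_grad = False

# With use_orig_params=False this mix raises a ValueError about
# non-uniform requires_grad inside the flattened parameter.
fsdp_model = FSDP(model, device_id=torch.cuda.current_device(), use_orig_params=True)
print(sum(p.numel() for p in fsdp_model.parameters() if p.requires_grad))

dist.destroy_process_group()
```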
58 changes: 56 additions & 2 deletions src/llama_recipes/utils/train_utils.py
@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
if i < num_layer:
for param in layer.parameters():
param.requires_grad = False


def freeze_LLM_only(model):
"""
Freeze the self-attention layers of the language_model; the vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
"""
# Freeze every parameter of the language model first ...
for name, param in model.language_model.named_parameters():
param.requires_grad = False
# ... then unfreeze the cross-attention decoder layers so they are fine-tuned
for i, layer in enumerate(model.language_model.model.layers):
if i in model.language_model.model.cross_attention_layers:
for param in layer.parameters():
param.requires_grad = True

def check_frozen_layers_peft_model(model):
for i, layer in enumerate(model.base_model.model.model.layers):
@@ -476,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")

def print_frozen_model_status(model, config, rank: int = 0) -> None:
"""
Print the per-module frozen/unfrozen status of the model and the number of trainable parameters after freezing.

Args:
model: The PyTorch model.
config: The training config; its model_name is used in the report.
rank (int, optional): Current process's rank. Defaults to 0.
"""
if rank == 0:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("After freezing the model:")
print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")

module_states = {}
# Iterate over all parameters
for name, param in model.named_parameters():
# Extract the top-level module name (e.g., "vision_model", "language_model")
top_module = name.split(".")[0]

# Initialize a record for the top-level module
if top_module not in module_states:
module_states[top_module] = {"frozen": [], "unfrozen": []}

# Group parameters into frozen or unfrozen
if param.requires_grad:
module_states[top_module]["unfrozen"].append(name)
else:
module_states[top_module]["frozen"].append(name)

print("--> Model state after freezing:")
# Analyze and print the results
for module, states in module_states.items():
frozen_params = states["frozen"]
unfrozen_params = states["unfrozen"]

if frozen_params and unfrozen_params:
# Mixed state: both frozen and unfrozen parameters
print(f" {module}: Mixed")
elif frozen_params:
# All parameters are frozen
print(f" {module}: Frozen")
else:
# All parameters are unfrozen
print(f" {module}: Unfrozen")
print("")

def get_policies(cfg, rank):
"""Get the policies for mixed precision and fsdp wrapping"""
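
As a usage illustration of the helpers above, a small sketch with toy stand-in modules (hypothetical names, not the real Mllama submodules) showing how `print_frozen_model_status` groups parameters by top-level module:

```python
# Freeze one sub-module of a toy model and report per-module status.
from types import SimpleNamespace

import torch.nn as nn

from llama_recipes.utils.train_utils import print_frozen_model_status

toy = nn.ModuleDict({
    "vision_model": nn.Linear(8, 8),
    "language_model": nn.Linear(8, 8),
})
for p in toy["language_model"].parameters():
    p.requires_grad = False

cfg = SimpleNamespace(model_name="toy-mllama")
# Prints the trainable parameter count, then
# "vision_model: Unfrozen" and "language_model: Frozen".
print_frozen_model_status(toy, cfg, rank=0)
```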