
add freeze_LLM_only option for mllama finetuning #791

Merged

Conversation

JimChienTW
Contributor

@JimChienTW JimChienTW commented Nov 16, 2024

What does this PR do?

Fixes #770

Feature/Issue Validation/Testing

To follow the training settings in the original paper, as discussed in issue #770, I added a new option that tunes the vision encoder, the projector, and the cross-attention layers inside the LLM while freezing the rest of the language model. Setting train_config.freeze_LLM_only to True enables this functionality (see the sketch below).
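
For reference, here is a minimal sketch of what enabling this option roughly does (not the exact code of this PR). It assumes the Hugging Face MllamaForConditionalGeneration layout, where the language model, vision encoder, and projector are exposed as language_model, vision_model, and multi_modal_projector, and the cross-attention layer indices are listed in config.text_config.cross_attention_layers:

```python
def freeze_LLM_only(model):
    """Freeze the language model except its cross-attention layers.

    The vision encoder and multi-modal projector are left untouched,
    so they stay trainable.
    """
    # Freeze every language-model parameter first.
    for param in model.language_model.parameters():
        param.requires_grad = False
    # Re-enable only the cross-attention layers inside the language model.
    for layer_idx in model.config.text_config.cross_attention_layers:
        for param in model.language_model.model.layers[layer_idx].parameters():
            param.requires_grad = True
```

After this step, vision_model and multi_modal_projector remain fully trainable, while language_model ends up only partially trainable.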

I conducted two tests:

  1. Using test_finetuning.py.
  2. Running the finetuning script finetuning.py directly.

Both tests passed successfully. For the second test, I ran the finetuning process on 8×H100 GPUs, and it went smoothly, as shown in the logs below.

  • python -m pytest src/tests/test_finetuning.py
=============================================================== test session starts ===============================================================
platform linux -- Python 3.11.9, pytest-8.3.3, pluggy-1.5.0
rootdir: /media/Pluto/jim/opensource_contribute/llama-recipes
configfile: pyproject.toml
plugins: mock-3.14.0, anyio-4.6.2.post1
collected 22 items                                                                                                                                

src/tests/test_finetuning.py ......................                                                                                         [100%]

================================================================ warnings summary =================================================================
../../llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17
  /media/Pluto/jim/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
    from torch.distributed._shard.checkpoint import (

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================== 22 passed, 1 warning in 3.76s ==========================================================
  • torchrun --nnodes 1 --nproc_per_node 8 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5 --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py" --run_validation True --batching_strategy padding --freeze_LLM_only True
W1116 16:26:22.743000 23456244184896 torch/distributed/run.py:757]
W1116 16:26:22.743000 23456244184896 torch/distributed/run.py:757] *****************************************
W1116 16:26:22.743000 23456244184896 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1116 16:26:22.743000 23456244184896 torch/distributed/run.py:757] *****************************************
in oss file
in oss file
in oss file
in oss file
in oss file
in oss file
in oss file
in oss file
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 10.94it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 12.16it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00,  9.53it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00,  7.23it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00, 11.39it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00,  7.16it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00,  7.32it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 5/5 [00:00<00:00,  8.04it/s]
bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 1333.777617 Million params

README.md: 100%|██████████████████████████████████████████████████| 50.3k/50.3k [00:00<00:00, 1.11MB/s]
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
(…)-00000-of-00011-f83c2bdf2cf711bf.parquet: 100%|██████████████████| 540M/540M [00:12<00:00, 42.0MB/s]
(…)-00001-of-00011-fef40eeeea84a563.parquet: 100%|██████████████████| 580M/580M [00:13<00:00, 42.1MB/s]
(…)-00002-of-00011-c0733bedbcc41420.parquet: 100%|██████████████████| 541M/541M [00:12<00:00, 42.3MB/s]
(…)-00003-of-00011-fee117dc7680fb5f.parquet: 100%|██████████████████| 577M/577M [00:13<00:00, 41.2MB/s]
(…)-00004-of-00011-c01c965b3ac5c2c0.parquet: 100%|██████████████████| 581M/581M [00:13<00:00, 42.6MB/s]
(…)-00005-of-00011-7eb79ee48c0c4065.parquet: 100%|██████████████████| 527M/527M [00:12<00:00, 42.6MB/s]
(…)-00006-of-00011-4a139e7c78fb5e47.parquet: 100%|██████████████████| 519M/519M [00:12<00:00, 41.5MB/s]
(…)-00007-of-00011-8f649db4d5664766.parquet: 100%|██████████████████| 559M/559M [00:24<00:00, 22.5MB/s]
(…)-00008-of-00011-23185b703995741f.parquet: 100%|██████████████████| 555M/555M [00:13<00:00, 42.6MB/s]
(…)-00009-of-00011-b0bb42debccbf310.parquet: 100%|██████████████████| 519M/519M [00:22<00:00, 22.7MB/s]
(…)-00010-of-00011-74ed380c1a2c83aa.parquet: 100%|██████████████████| 579M/579M [00:14<00:00, 41.0MB/s]
Generating train split: 100%|████████████████████████| 165746/165746 [00:06<00:00, 27228.77 examples/s]
--> Training Set Length = 1800
--> Validation Set Length = 200
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
--> Num of Validation Set Batches loaded = 25
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                       | 0/112 [00:00<?, ?it/s]
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
--> Num of Validation Set Batches loaded = 25
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                       | 0/112 [00:00<?, ?it/s]
--> Num of Validation Set Batches loaded = 25
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
--> Num of Validation Set Batches loaded = 25
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
--> Num of Validation Set Batches loaded = 25
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                       | 0/112 [00:00<?, ?it/s]
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                       | 0/112 [00:00<?, ?it/s]
--> Num of Validation Set Batches loaded = 25
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 112
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                       | 0/112 [00:00<?, ?it/s]
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                       | 0/112 [00:00<?, ?it/s]
--> Num of Validation Set Batches loaded = 25
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
--> Num of Validation Set Batches loaded = 25
--> Num of Validation Set Batches loaded = 25
Starting epoch 0/3
train_config.max_train_step: 0
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                       | 0/112 [00:00<?, ?it/s]
/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                       | 0/112 [00:00<?, ?it/s]
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Training Epoch: 1/3, step 19/112 completed (loss: 0.032936301082372665):  18%|▏| 20/112 [00:48<02:44,
Training Epoch: 1/3, step 20/112 completed (loss: 0.03712736815214157):  19%|▏| 21/112 [00:50<02:42,  1
Training Epoch: 1/3, step 22/112 completed (loss: 0.11487767100334167):  21%|▏| 23/112 [00:53<02:38,

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Thanks for contributing 🎉!

@JimChienTW JimChienTW marked this pull request as ready for review November 16, 2024 08:43
@JimChienTW
Contributor Author

This is my first time contributing to open source, and I’d really appreciate any feedback or advice you can share!

@init27 init27 requested a review from wukaixingxp November 17, 2024 20:29
@init27
Contributor

init27 commented Nov 17, 2024

@JimChienTW Really appreciate you contributing to our repository, and congrats on your first contribution. We will review your PR this week.

Thanks again!

@wukaixingxp
Contributor

@JimChienTW Thanks for your PR, but I wonder why my run with freeze_LLM_only has 709.622115 Million trainable params for the 11B model, and the run without freeze_LLM_only has 2667.555217 Million params, when Llama-3.2-11B-Vision-Instruct should have 10670.220835 Million params. Can you double-check whether something is wrong? Please see the logs below:

with freeze_LLM log:

~/work/to_merge/llama-recipes (add_vision_finetuning_features)]$ torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding --freeze_LLM_only True
W1118 13:42:20.313000 140448872657920 torch/distributed/run.py:779] 
W1118 13:42:20.313000 140448872657920 torch/distributed/run.py:779] *****************************************
W1118 13:42:20.313000 140448872657920 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1118 13:42:20.313000 140448872657920 torch/distributed/run.py:779] *****************************************
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.64it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.59it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.63it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.62it/s]
bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 709.622115 Million params

--> Training Set Length = 1800
--> Validation Set Length = 200
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
NCCL version 2.20.5+cuda12.4
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Training Epoch: 1/3, step 49/225 completed (loss: 0.1835767775774002):  ^C
W1118 13:44:23.469000 140448872657920 torch/distributed/elastic/agent/server/api.py:688] Received Signals.SIGINT death signal, shutting down workers

and without freeze_LLM:

~/work/to_merge/llama-recipes (add_vision_finetuning_features)]$ torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding 
W1118 13:44:32.212000 140053915866112 torch/distributed/run.py:779] 
W1118 13:44:32.212000 140053915866112 torch/distributed/run.py:779] *****************************************
W1118 13:44:32.212000 140053915866112 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1118 13:44:32.212000 140053915866112 torch/distributed/run.py:779] *****************************************
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/to_merge/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.27it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  7.22it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.84it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.51it/s]
bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 2667.555217 Million params

--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Training Set Length = 1800
--> Validation Set Length = 200
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 225
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                        | 0/225 [00:00<?, ?it/s]
NCCL version 2.20.5+cuda12.4
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/kaiwu/miniconda3/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Training Epoch: 1/3, step 9/225 completed (loss: 0.24111326038837433)

@wukaixingxp
Contributor

run with latest main:

torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding 
W1118 13:56:29.641000 140413946110976 torch/distributed/run.py:779] 
W1118 13:56:29.641000 140413946110976 torch/distributed/run.py:779] *****************************************
W1118 13:56:29.641000 140413946110976 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1118 13:56:29.641000 140413946110976 torch/distributed/run.py:779] *****************************************
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
/home/kaiwu/work/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.72it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  6.82it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00,  7.02it/s]
Loading checkpoint shards: 100%|██████████| 5/5 [00:01<00:00,  4.56it/s]
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 10670.220835 Million params

bFloat16 enabled for mixed precision - using bfSixteen policy
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...

@JimChienTW
Contributor Author

Thank you for your review. I found the discrepancy was caused by printing the model parameters after FSDP wrapping, where each rank only reports its local shard (e.g., in your 4-GPU run without freezing, 2667.555217 Million params per rank × 4 ranks ≈ 10670.22 Million params in total). Printing the counts before FSDP wrapping solves the problem; a sketch of the reporting helpers follows, and the updated logs are below.
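
As an illustration, here is a hypothetical helper (not the exact code of this PR) that reports the counts from the unwrapped model, i.e. after loading and freezing but before FSDP wrapping is applied:

```python
def print_param_counts(model, name):
    # Total vs. trainable parameter counts, in millions, on the unwrapped model.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"--> {name} has {total / 1e6:.6f} Million params")
    print(f"--> {name} has {trainable / 1e6:.6f} Million trainable params")


def print_frozen_state(model):
    # Summarize each top-level submodule as Frozen / Unfrozen / Mixed.
    print("--> Model state after freezing:")
    for name, module in model.named_children():
        flags = [p.requires_grad for p in module.parameters()]
        if not flags:
            continue  # skip submodules without parameters
        if all(flags):
            state = "Unfrozen"
        elif not any(flags):
            state = "Frozen"
        else:
            state = "Mixed"
        print(f"    {name}: {state}")
```

Called before FSDP wrapping, this produces the kind of output shown in the updated logs below.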

with freeze_LLM log:

torchrun --nnodes 1 --nproc_per_node 2  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding --freeze_LLM_only True
W1119 16:12:05.142000 844228 site-packages/torch/distributed/run.py:793] 
W1119 16:12:05.142000 844228 site-packages/torch/distributed/run.py:793] *****************************************
W1119 16:12:05.142000 844228 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1119 16:12:05.142000 844228 site-packages/torch/distributed/run.py:793] *****************************************
/media/Pluto/jim/opensource_contribute/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/media/Pluto/jim/opensource_contribute/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 12.54it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 12.68it/s]
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 10670.220835 Million params

After freezing the model:
--> meta-llama/Llama-3.2-11B-Vision-Instruct has 2639.926819 Million trainable params

--> Model state after freezing:
    vision_model: Unfrozen
    language_model: Mixed
    multi_modal_projector: Unfrozen

bFloat16 enabled for mixed precision - using bfSixteen policy

and without freeze_LLM:

W1119 16:17:19.279000 844690 site-packages/torch/distributed/run.py:793] 
W1119 16:17:19.279000 844690 site-packages/torch/distributed/run.py:793] *****************************************
W1119 16:17:19.279000 844690 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1119 16:17:19.279000 844690 site-packages/torch/distributed/run.py:793] *****************************************
/media/Pluto/jim/opensource_contribute/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/media/Pluto/jim/opensource_contribute/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 13.18it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 13.68it/s]
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 10670.220835 Million params

bFloat16 enabled for mixed precision - using bfSixteen policy

@wukaixingxp wukaixingxp self-assigned this Nov 20, 2024
Contributor

@wukaixingxp wukaixingxp left a comment


This PR adds a new freeze_LLM_only feature for mllama fine-tuning and a correct way to print out the unfrozen weights. It has been tested and everything looks good to me. Thanks for your contribution to llama-recipes!

@wukaixingxp wukaixingxp merged commit e5662e5 into meta-llama:main Nov 20, 2024
4 checks passed
Successfully merging this pull request may close these issues: Llama 3.2 Vision Models Fine-Tuning Recipe (#770)