From d31ee18e4f4b412e92978726c6f90c6980ce3279 Mon Sep 17 00:00:00 2001
From: JimChienTW
Date: Tue, 19 Nov 2024 16:19:28 +0800
Subject: [PATCH] Print model's frozen status after freezing

---
 src/llama_recipes/finetuning.py        |  9 +++--
 src/llama_recipes/utils/train_utils.py | 46 +++++++++++++++++++++++++-
 2 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/src/llama_recipes/finetuning.py b/src/llama_recipes/finetuning.py
index 2a86234c3..75ef6337f 100644
--- a/src/llama_recipes/finetuning.py
+++ b/src/llama_recipes/finetuning.py
@@ -41,6 +41,7 @@
     freeze_LLM_only,
     get_policies,
     print_model_size,
+    print_frozen_model_status,
     setup,
     setup_environ_flags,
     train,
@@ -194,6 +195,8 @@ def main(**kwargs):
         )
         model.resize_token_embeddings(len(tokenizer))
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -234,12 +237,14 @@ def main(**kwargs):
     if not train_config.use_peft and train_config.freeze_layers:
         freeze_transformer_layers(model, train_config.num_freeze_layers)
+        # print the model's frozen status after freezing decoder layers
+        print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
 
     if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
         freeze_LLM_only(model)
+        # print the model's frozen status after freezing the language model
+        print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
 
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
     mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
 
     # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
     if is_vision:
diff --git a/src/llama_recipes/utils/train_utils.py b/src/llama_recipes/utils/train_utils.py
index cec2df784..c594b6a1e 100644
--- a/src/llama_recipes/utils/train_utils.py
+++ b/src/llama_recipes/utils/train_utils.py
@@ -486,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
         print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
 
-
+def print_frozen_model_status(model, config, rank: int = 0) -> None:
+    """
+    Print the frozen status of the model's modules and the number of trainable parameters after freezing.
+
+    Args:
+        model: The PyTorch model.
+        config: Train config object; its model_name field is used in the report.
+        rank (int, optional): Current process's rank. Defaults to 0.
+ """ + if rank == 0: + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("After freezing the model:") + print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n") + + module_states = {} + # Iterate over all parameters + for name, param in model.named_parameters(): + # Extract the top-level module name (e.g., "vision_model", "language_model") + top_module = name.split(".")[0] + + # Initialize a record for the top-level module + if top_module not in module_states: + module_states[top_module] = {"frozen": [], "unfrozen": []} + + # Group parameters into frozen or unfrozen + if param.requires_grad: + module_states[top_module]["unfrozen"].append(name) + else: + module_states[top_module]["frozen"].append(name) + + print("--> Model state after freezing:") + # Analyze and print the results + for module, states in module_states.items(): + frozen_params = states["frozen"] + unfrozen_params = states["unfrozen"] + + if frozen_params and unfrozen_params: + # Mixed state: both frozen and unfrozen parameters + print(f" {module}: Mixed") + elif frozen_params: + # All parameters are frozen + print(f" {module}: Frozen") + else: + # All parameters are unfrozen + print(f" {module}: Unfrozen") + print("") def get_policies(cfg, rank): """Get the policies for mixed precision and fsdp wrapping"""